diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..56d25c9 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +include README.md +include requirements*.txt +include docs/*.rst +include docs/img/*.png diff --git a/README.md b/README.md index 3615b09..0ecc1e8 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,177 @@ # numpy-ml Ever wish you had an inefficient but somewhat legible collection of machine -learning algorithms implemented exclusively in numpy? No? +learning algorithms implemented exclusively in NumPy? No? + +## Installation + +### For rapid experimentation +To use this code as a starting point for ML prototyping / experimentation, just clone the repository, create a new [virtualenv](https://pypi.org/project/virtualenv/), and start hacking: + +```sh +$ git clone https://github.com/ddbourgin/numpy-ml.git +$ cd numpy-ml && virtualenv npml && source npml/bin/activate +$ pip3 install -r requirements-dev.txt +``` + +### As a package +If you don't plan to modify the source, you can also install numpy-ml as a +Python package: `pip3 install -u numpy_ml`. + +The reinforcement learning agents train on environments defined in the [OpenAI +gym](https://github.com/openai/gym). To install these alongside numpy-ml, you +can use `pip3 install -u 'numpy_ml[rl]'`. ## Documentation -To see all of the available models, take a look at the [project documentation](https://numpy-ml.readthedocs.io/) or see [here](https://github.com/ddbourgin/numpy-ml/blob/master/numpy_ml/README.md). +For more details on the available models, see the [project documentation](https://numpy-ml.readthedocs.io/). + +## Available models +1. **Gaussian mixture model** + - EM training + +2. **Hidden Markov model** + - Viterbi decoding + - Likelihood computation + - MLE parameter estimation via Baum-Welch/forward-backward algorithm + +3. **Latent Dirichlet allocation** (topic model) + - Standard model with MLE parameter estimation via variational EM + - Smoothed model with MAP parameter estimation via MCMC + +4. **Neural networks** + * Layers / Layer-wise ops + - Add + - Flatten + - Multiply + - Softmax + - Fully-connected/Dense + - Sparse evolutionary connections + - LSTM + - Elman-style RNN + - Max + average pooling + - Dot-product attention + - Embedding layer + - Restricted Boltzmann machine (w. CD-n training) + - 2D deconvolution (w. padding and stride) + - 2D convolution (w. padding, dilation, and stride) + - 1D convolution (w. padding, dilation, stride, and causality) + * Modules + - Bidirectional LSTM + - ResNet-style residual blocks (identity and convolution) + - WaveNet-style residual blocks with dilated causal convolutions + - Transformer-style multi-headed scaled dot product attention + * Regularizers + - Dropout + * Normalization + - Batch normalization (spatial and temporal) + - Layer normalization (spatial and temporal) + * Optimizers + - SGD w/ momentum + - AdaGrad + - RMSProp + - Adam + * Learning Rate Schedulers + - Constant + - Exponential + - Noam/Transformer + - Dlib scheduler + * Weight Initializers + - Glorot/Xavier uniform and normal + - He/Kaiming uniform and normal + - Standard and truncated normal + * Losses + - Cross entropy + - Squared error + - Bernoulli VAE loss + - Wasserstein loss with gradient penalty + - Noise contrastive estimation loss + * Activations + - ReLU + - Tanh + - Affine + - Sigmoid + - Leaky ReLU + - ELU + - SELU + - Exponential + - Hard Sigmoid + - Softplus + * Models + - Bernoulli variational autoencoder + - Wasserstein GAN with gradient penalty + - word2vec encoder with skip-gram and CBOW architectures + * Utilities + - `col2im` (MATLAB port) + - `im2col` (MATLAB port) + - `conv1D` + - `conv2D` + - `deconv2D` + - `minibatch` + +5. **Tree-based models** + - Decision trees (CART) + - [Bagging] Random forests + - [Boosting] Gradient-boosted decision trees + +6. **Linear models** + - Ridge regression + - Logistic regression + - Ordinary least squares + - Bayesian linear regression w/ conjugate priors + - Unknown mean, known variance (Gaussian prior) + - Unknown mean, unknown variance (Normal-Gamma / Normal-Inverse-Wishart prior) + +7. **n-Gram sequence models** + - Maximum likelihood scores + - Additive/Lidstone smoothing + - Simple Good-Turing smoothing + +8. **Multi-armed bandit models** + - UCB1 + - LinUCB + - Epsilon-greedy + - Thompson sampling w/ conjugate priors + - Beta-Bernoulli sampler + - LinUCB + +8. **Reinforcement learning models** + - Cross-entropy method agent + - First visit on-policy Monte Carlo agent + - Weighted incremental importance sampling Monte Carlo agent + - Expected SARSA agent + - TD-0 Q-learning agent + - Dyna-Q / Dyna-Q+ with prioritized sweeping + +9. **Nonparameteric models** + - Nadaraya-Watson kernel regression + - k-Nearest neighbors classification and regression + - Gaussian process regression + +10. **Matrix factorization** + - Regularized alternating least-squares + - Non-negative matrix factorization + +11. **Preprocessing** + - Discrete Fourier transform (1D signals) + - Discrete cosine transform (type-II) (1D signals) + - Bilinear interpolation (2D signals) + - Nearest neighbor interpolation (1D and 2D signals) + - Autocorrelation (1D signals) + - Signal windowing + - Text tokenization + - Feature hashing + - Feature standardization + - One-hot encoding / decoding + - Huffman coding / decoding + - Term frequency-inverse document frequency (TF-IDF) encoding + - MFCC encoding + +12. **Utilities** + - Similarity kernels + - Distance metrics + - Priority queue + - Ball tree + - Discrete sampler + - Graph processing and generators ## Contributing diff --git a/numpy_ml/bandits/bandits.py b/numpy_ml/bandits/bandits.py index d66fbaa..e5a14c3 100644 --- a/numpy_ml/bandits/bandits.py +++ b/numpy_ml/bandits/bandits.py @@ -4,7 +4,7 @@ import numpy as np -from ..utils.testing import random_one_hot_matrix, is_number +from numpy_ml.utils.testing import random_one_hot_matrix, is_number class Bandit(ABC): @@ -104,6 +104,7 @@ def __init__(self, payoffs, payoff_probs): self.payoff_probs = payoff_probs self.arm_evs = np.array([sum(p * v) for p, v in zip(payoff_probs, payoffs)]) self.best_ev = np.max(self.arm_evs) + self.best_arm = np.argmax(self.arm_evs) @property def hyperparameters(self): @@ -127,8 +128,10 @@ def oracle_payoff(self, context=None): ------- optimal_rwd : float The expected reward under an optimal policy. + optimal_arm : float + The arm ID with the largest expected reward. """ - return self.best_ev + return self.best_ev, self.best_arm def _pull(self, arm_id, context): payoffs = self.payoffs[arm_id] @@ -159,6 +162,7 @@ def __init__(self, payoff_probs): self.arm_evs = self.payoff_probs self.best_ev = np.max(self.arm_evs) + self.best_arm = np.argmax(self.arm_evs) @property def hyperparameters(self): @@ -181,8 +185,10 @@ def oracle_payoff(self, context=None): ------- optimal_rwd : float The expected reward under an optimal policy. + optimal_arm : float + The arm ID with the largest expected reward. """ - return self.best_ev + return self.best_ev, self.best_arm def _pull(self, arm_id, context): return int(np.random.rand() <= self.payoff_probs[arm_id]) @@ -217,6 +223,7 @@ def __init__(self, payoff_dists, payoff_probs): self.payoff_probs = payoff_probs self.arm_evs = np.array([mu for (mu, var) in payoff_dists]) self.best_ev = np.max(self.arm_evs) + self.best_arm = np.argmax(self.arm_evs) @property def hyperparameters(self): @@ -249,8 +256,10 @@ def oracle_payoff(self, context=None): ------- optimal_rwd : float The expected reward under an optimal policy. + optimal_arm : float + The arm ID with the largest expected reward. """ - return self.best_ev + return self.best_ev, self.best_arm class ShortestPathBandit(Bandit): @@ -282,6 +291,7 @@ def __init__(self, G, start_vertex, end_vertex): self.arm_evs = self._calc_arm_evs() self.best_ev = np.max(self.arm_evs) + self.best_arm = np.argmax(self.arm_evs) placeholder = [None] * len(self.paths) super().__init__(placeholder, placeholder) @@ -309,8 +319,10 @@ def oracle_payoff(self, context=None): ------- optimal_rwd : float The expected reward under an optimal policy. + optimal_arm : float + The arm ID with the largest expected reward. """ - return self.best_ev + return self.best_ev, self.best_arm def _calc_arm_evs(self): I2V = self.G.get_vertex @@ -353,7 +365,8 @@ def __init__(self, context_probs): self.context_probs = context_probs self.arm_evs = self.context_probs - self.best_ev = self.arm_evs.max(axis=1) + self.best_evs = self.arm_evs.max(axis=1) + self.best_arms = self.arm_evs.argmax(axis=1) @property def hyperparameters(self): @@ -386,15 +399,17 @@ def oracle_payoff(self, context): Parameters ---------- context : :py:class:`ndarray ` of shape `(D, K)` or None - The current context matrix for each of the bandit arms, if - applicable. Default is None. + The current context matrix for each of the bandit arms. Returns ------- optimal_rwd : float The expected reward under an optimal policy. + optimal_arm : float + The arm ID with the largest expected reward. """ - return context[:, 0] @ self.best_ev + context_id = context[:, 0].argmax() + return self.best_evs[context_id], self.best_arms[context_id] def _pull(self, arm_id, context): D, K = self.context_probs.shape @@ -499,9 +514,11 @@ def oracle_payoff(self, context): ------- optimal_rwd : float The expected reward under an optimal policy. + optimal_arm : float + The arm ID with the largest expected reward. """ best_arm = np.argmax(self.arm_evs) - return self.arm_evs[best_arm] + return self.arm_evs[best_arm], best_arm def _pull(self, arm_id, context): K, thetas = self.K, self.thetas diff --git a/numpy_ml/bandits/policies.py b/numpy_ml/bandits/policies.py index a3f3bb6..6d4c4b9 100644 --- a/numpy_ml/bandits/policies.py +++ b/numpy_ml/bandits/policies.py @@ -202,13 +202,12 @@ def __init__(self, C=1, ev_prior=0.5): \text{UCB}(a, t) = \text{EV}_t(a) + C \sqrt{\frac{2 \log t}{N_t(a)}} - where :math:`\text{UCB}(a, t)` is the upper confidence bound on the - expected value of arm `a` at time `t`, :math:`\text{EV}_t(a)` is the - average of the rewards recieved so far from pulling arm `a`, `C` is a - parameter controlling the confidence upper bound of the estimate for - :math:`\text{UCB}(a, t)` (for logarithmic regret bounds, `C` must - equal 1), and :math:`N_t(a)` is the number of times arm `a` has been - pulled during the previous `t - 1` timesteps. + where :math:`\text{EV}_t(a)` is the average of the rewards recieved so + far from pulling arm `a`, `C` is a free parameter controlling the + "optimism" of the confidence upper bound for :math:`\text{UCB}(a, t)` + (for logarithmic regret bounds, `C` must equal 1), and :math:`N_t(a)` + is the number of times arm `a` has been pulled during the previous `t - + 1` timesteps. References ---------- @@ -220,7 +219,8 @@ def __init__(self, C=1, ev_prior=0.5): ---------- C : float in (0, +infinity) A confidence/optimisim parameter affecting the degree of - exploration. The UCB1 algorithm assumes `C=1`. Default is 1. + exploration, where larger values encourage greater exploration. The + UCB1 algorithm assumes `C=1`. Default is 1. ev_prior : float The starting expected value for each arm before any data has been observed. Default is 0.5. @@ -292,10 +292,10 @@ def __init__(self, alpha=1, beta=1): where :math:`k \in \{1,\ldots,K \}` indexes arms in the MAB and :math:`\theta_k` is the parameter of the Bernoulli likelihood for arm `k`. The sampler begins by selecting an arm with probability - proportional to it's payoff probability under the initial Beta prior. + proportional to its payoff probability under the initial Beta prior. After pulling the sampled arm and receiving a reward, `r`, the sampler computes the posterior over the model parameters (arm payoffs) via - Bayes' rule, and then samples a new action in proportion to it's payoff + Bayes' rule, and then samples a new action in proportion to its payoff probability under this posterior. This process (i.e., sample action from posterior, take action and receive reward, compute updated posterior) is repeated until the number of trials is exhausted. diff --git a/numpy_ml/bandits/trainer.py b/numpy_ml/bandits/trainer.py index 79211e0..f925d32 100644 --- a/numpy_ml/bandits/trainer.py +++ b/numpy_ml/bandits/trainer.py @@ -1,16 +1,20 @@ """A trainer/runner object for executing and comparing MAB policies.""" +import warnings import os.path as op from collections import defaultdict import numpy as np +from numpy_ml.utils.testing import DependencyWarning + try: import matplotlib.pyplot as plt _PLOTTING = True except ImportError: - print("Cannot import matplotlib. Plotting functionality disabled.") + fstr = "Cannot import matplotlib. Plotting functionality disabled." + warnings.warn(fstr, DependencyWarning) _PLOTTING = False @@ -84,11 +88,12 @@ def compare( self, policies, bandit, - ep_length, - n_episodes, + n_trials, n_duplicates, - seed=12345, + plot=True, + seed=None, smooth_weight=0.999, + out_dir=None, ): """ Compare the performance of multiple policies on the same bandit @@ -100,38 +105,49 @@ def compare( The multi-armed bandit policies to compare. bandit : :class:`Bandit ` instance The environment to train the policies on. - ep_length : int - The number of pulls allowed in each episode - n_episodes : int - The number of episodes per run + n_trials : int + The number of trials per run. n_duplicates: int - The number of runs to evaluate + The number of times to evaluate each policy on the bandit + environment. Larger values permit a better estimate of the + variance in payoff / cumulative regret for each policy. + plot : bool + Whether to generate a plot of the policy's average reward and + regret across the episodes. Default is True. seed : int - The seed for the random number generator. Default is 12345. + The seed for the random number generator. Default is None. smooth_weight : float in [0, 1] The smoothing weight. Values closer to 0 result in less smoothing, values closer to 1 produce more aggressive smoothing. Default is 0.999. + out_dir : str or None + Plots will be saved to this directory if `plot` is True. If + `out_dir` is None, plots will not be saved. Default is None. """ # noqa: E501 self.init_logs(policies) - fig, all_axes = plt.subplots(len(policies), 2, sharex=True) - fig.set_size_inches(10.5, len(policies) * 5.25) + + all_axes = [None] * len(policies) + if plot and _PLOTTING: + fig, all_axes = plt.subplots(len(policies), 2, sharex=True) + fig.set_size_inches(10.5, len(policies) * 5.25) for policy, axes in zip(policies, all_axes): - np.random.seed(seed) + if seed: + np.random.seed(seed) + bandit.reset() policy.reset() self.train( policy, bandit, - ep_length, - n_episodes, + n_trials, n_duplicates, axes=axes, - plot=True, + plot=plot, verbose=False, - smooth_weight=0.999, + out_dir=out_dir, + smooth_weight=smooth_weight, ) # enforce the same y-ranges across plots for straightforward comparison @@ -146,24 +162,23 @@ def compare( a1.set_ylim(a1_min, a1_max) a2.set_ylim(a2_min, a2_max) - sdir = get_scriptdir() - plt.savefig("{}/img/{}.png".format(sdir, "comparison"), dpi=300) - - plt.show() - plt.close("all") + if plot and _PLOTTING: + if out_dir is not None: + plt.savefig(op.join(out_dir, "bandit_comparison.png"), dpi=300) + plt.show() def train( self, policy, bandit, - ep_length, - n_episodes, + n_trials, n_duplicates, plot=True, axes=None, verbose=True, print_every=100, smooth_weight=0.999, + out_dir=None, ): """ Train a MAB policies on a multi-armed bandit problem, logging training @@ -175,10 +190,8 @@ def train( The multi-armed bandit policy to train. bandit : :class:`Bandit ` instance The environment to run the policy on. - ep_length : int - The number of pulls allowed in each episode - n_episodes : int - The number of episodes per run + n_trials : int + The number of trials per run. n_duplicates: int The number of runs to evaluate plot : bool @@ -197,6 +210,9 @@ def train( The smoothing weight. Values closer to 0 result in less smoothing, values closer to 1 produce more aggressive smoothing. Default is 0.999. + out_dir : str or None + Plots will be saved to this directory if `plot` is True. If + `out_dir` is None, plots will not be saved. Default is None. Returns ------- @@ -218,33 +234,34 @@ def train( policy.reset() avg_oracle_reward, cregret = 0, 0 - for e_id in range(n_episodes): - oracle_reward, ep_reward = 0, 0 - - for s in range(ep_length): - rwd, arm, orwd = self._train_step(bandit, policy) - ep_reward += rwd - oracle_reward += orwd + for trial_id in range(n_trials): + rwd, arm, orwd, oarm = self._train_step(bandit, policy) loss = mse(bandit, policy) - regret = oracle_reward - ep_reward - avg_oracle_reward += oracle_reward / n_episodes + regret = orwd - rwd + + avg_oracle_reward += orwd cregret += regret - L[p]["mse"][e_id + 1].append(loss) - L[p]["regret"][e_id + 1].append(regret) - L[p]["cregret"][e_id + 1].append(cregret) - L[p]["reward"][e_id + 1].append(ep_reward) + L[p]["mse"][trial_id + 1].append(loss) + L[p]["reward"][trial_id + 1].append(rwd) + L[p]["regret"][trial_id + 1].append(regret) + L[p]["cregret"][trial_id + 1].append(cregret) + L[p]["optimal_arm"][trial_id + 1].append(oarm) + L[p]["selected_arm"][trial_id + 1].append(arm) + L[p]["optimal_reward"][trial_id + 1].append(orwd) - if (e_id + 1) % print_every == 0 and verbose: - fstr = "Ep. {}/{}, {}/{}, Regret: {:.4f}" - print(fstr.format(e_id + 1, n_episodes, d + 1, D, regret)) + if (trial_id + 1) % print_every == 0 and verbose: + fstr = "Trial {}/{}, {}/{}, Regret: {:.4f}" + print(fstr.format(trial_id + 1, n_trials, d + 1, D, regret)) + + avg_oracle_reward /= n_trials if verbose: self._print_run_summary(bandit, policy, regret) - if plot: - self._plot_reward(avg_oracle_reward, policy, smooth_weight, axes) + if plot and _PLOTTING: + self._plot_reward(avg_oracle_reward, policy, smooth_weight, axes, out_dir) return policy @@ -252,8 +269,8 @@ def _train_step(self, bandit, policy): P, B = policy, bandit C = B.get_context() if hasattr(B, "get_context") else None rwd, arm = P.act(B, C) - oracle_rwd = B.oracle_payoff(C) - return rwd, arm, oracle_rwd + oracle_rwd, oracle_arm = B.oracle_payoff(C) + return rwd, arm, oracle_rwd, oracle_arm def init_logs(self, policies): """ @@ -261,20 +278,30 @@ def init_logs(self, policies): Notes ----- - In the logs, keys are episode numbers, and values are lists of length - ``n_duplicates`` holding the metric values for each duplicate of that - episode. For example, ``logs['regret'][3][1]`` holds the regret value - accrued on the 2nd duplicate of the 4th episode. + Training logs are represented as a nested set of dictionaries with the + following structure: + + log[model_id][metric][trial_number][duplicate_number] + + For example, ``logs['model1']['regret'][3][1]`` holds the regret value + accrued on the 3rd trial of the 2nd duplicate run for model1. + + Available fields are 'regret', 'cregret' (cumulative regret), 'reward', + 'mse' (mean-squared error between estimated arm EVs and the true EVs), + 'optimal_arm', 'selected_arm', and 'optimal_reward'. """ if not isinstance(policies, list): policies = [policies] self.logs = { str(p): { + "mse": defaultdict(lambda: []), "regret": defaultdict(lambda: []), - "cregret": defaultdict(lambda: []), "reward": defaultdict(lambda: []), - "mse": defaultdict(lambda: []), + "cregret": defaultdict(lambda: []), + "optimal_arm": defaultdict(lambda: []), + "selected_arm": defaultdict(lambda: []), + "optimal_reward": defaultdict(lambda: []), } for p in policies } @@ -293,11 +320,7 @@ def _print_run_summary(self, bandit, policy, regret): fstr = "\nFinal MSE: {:.4f}\nFinal Regret: {:.4f}\n\n" print(fstr.format(np.mean(se), regret)) - def _plot_reward(self, optimal_rwd, policy, smooth_weight, axes=None): - if not _PLOTTING: - print("Cannot import matplotlib. Plotting functionality disabled.") - return - + def _plot_reward(self, optimal_rwd, policy, smooth_weight, axes=None, out_dir=None): L = self.logs[str(policy)] smds = self._smoothed_metrics(policy, optimal_rwd, smooth_weight) @@ -335,12 +358,10 @@ def _plot_reward(self, optimal_rwd, policy, smooth_weight, axes=None): fig.suptitle(str(policy)) fig.tight_layout() - sdir = get_scriptdir() - bid = policy.hyperparameters["id"] - plt.savefig("{}/img/{}.png".format(sdir, bid), dpi=300) - + if out_dir is not None: + bid = policy.hyperparameters["id"] + plt.savefig(op.join(out_dir, f"{bid}.png"), dpi=300) plt.show() - plt.close("all") return ax1, ax2 def _smoothed_metrics(self, policy, optimal_rwd, smooth_weight): @@ -349,6 +370,9 @@ def _smoothed_metrics(self, policy, optimal_rwd, smooth_weight): # pre-allocate smoothed data structure smds = {} for m in L.keys(): + if m == "selections": + continue + smds["sm_{}_avg".format(m)] = np.zeros(len(L["reward"])) smds["sm_{}_avg".format(m)][0] = np.mean(L[m][1]) @@ -358,6 +382,8 @@ def _smoothed_metrics(self, policy, optimal_rwd, smooth_weight): smoothed = {m: L[m][1] for m in L.keys()} for e_id in range(2, len(L["reward"]) + 1): for m in L.keys(): + if m == "selections": + continue prev, cur = smoothed[m], L[m][e_id] smoothed[m] = [smooth(p, c, smooth_weight) for p, c in zip(prev, cur)] smds["sm_{}_avg".format(m)][e_id - 1] = np.mean(smoothed[m]) diff --git a/numpy_ml/hmm/hmm.py b/numpy_ml/hmm/hmm.py index a7e84db..51e8ec7 100644 --- a/numpy_ml/hmm/hmm.py +++ b/numpy_ml/hmm/hmm.py @@ -1,9 +1,11 @@ +"""Hidden Markov model module""" + import numpy as np class MultinomialHMM: def __init__(self, A=None, B=None, pi=None, eps=None): - """ + r""" A simple hidden Markov model with multinomial emission distribution. Parameters @@ -68,10 +70,10 @@ def __init__(self, A=None, B=None, pi=None, eps=None): self.B[self.B == 0] = self.eps # set of training sequences - self.O = None + self.O = None # noqa: E741 # number of sequences in O - self.I = None + self.I = None # noqa: E741 # number of observations in each sequence self.T = None @@ -115,10 +117,10 @@ def generate(self, n_steps, latent_state_types, obs_types): return np.array(states), np.array(emissions) def log_likelihood(self, O): - """ + r""" Given the HMM parameterized by :math:`(A`, B, \pi)` and an observation sequence `O`, compute the marginal likelihood of the observations: - :math:`P(O|A,B,\pi)`, summing over latent states. + :math:`P(O \mid A,B,\pi)`, summing over latent states. Notes ----- @@ -128,7 +130,9 @@ def log_likelihood(self, O): probability under the HMM of being in latent state `i` after seeing the first `j` observations: - .. math:: \mathtt{forward[i,j]} = P(o_1,\ldots,o_j,q_j=i \mid A,B,\pi) + .. math:: + + \mathtt{forward[i,j]} = P(o_1, \ldots, o_j, q_j=i \mid A, B, \pi) Here :math:`q_j = i` indicates that the hidden state at time `j` is of type `i`. @@ -137,12 +141,11 @@ def log_likelihood(self, O): .. math:: - \mathtt{forward[i,j]} &= \sum_{s'=1}^N \mathtt{forward[s',j-1]} - \cdot \mathtt{A[s',i]} \cdot \mathtt{B[i,o_j]} \\ - - &= \sum_{s'=1}^N - P(o_1,\ldots,o_{j-1},q_{j-1}=s' \mid A,B,\pi) - P(q_j=i|q_{j-1}=s') P(o_j \mid q_j=i) + \mathtt{forward[i,j]} + &= \sum_{s'=1}^N \mathtt{forward[s',j-1]} \cdot + \mathtt{A[s',i]} \cdot \mathtt{B[i,o_j]} \\ + &= \sum_{s'=1}^N P(o_1, \ldots, o_{j-1}, q_{j-1}=s' \mid A, B, \pi) + P(q_j=i \mid q_{j-1}=s') P(o_j \mid q_j=i) In words, ``forward[i,j]`` is the weighted sum of the values computed on the previous timestep. The weight on each previous state value is the @@ -160,11 +163,11 @@ def log_likelihood(self, O): The likelihood of the observations `O` under the HMM. """ if O.ndim == 1: - O = O.reshape(1, -1) + O = O.reshape(1, -1) # noqa: E741 - I, T = O.shape + I, T = O.shape # noqa: E741 - if I != 1: + if I != 1: # noqa: E741 raise ValueError("Likelihood only accepts a single sequence") forward = self._forward(O[0]) @@ -172,7 +175,7 @@ def log_likelihood(self, O): return log_likelihood def decode(self, O): - """ + r""" Given the HMM parameterized by :math:`(A, B, \pi)` and an observation sequence :math:`O = o_1, \ldots, o_T`, compute the most probable sequence of latent states, :math:`Q = q_1, \ldots, q_T`. @@ -187,7 +190,8 @@ def decode(self, O): .. math:: \mathtt{viterbi[i,j]} = - \max_{q_1,\ldots,q_{j-1}} P(o_1,\ldots,o_j,q_1,\ldots,q_{j-1},q_j=i \mid A,B,\pi) + \max_{q_1, \ldots, q_{j-1}} + P(o_1, \ldots, o_j, q_1, \ldots, q_{j-1}, q_j=i \mid A, B, \pi) Here :math:`q_j = i` indicates that the hidden state at time `j` is of type `i`, and :math:`\max_{q_1,\ldots,q_{j-1}}` represents the maximum over @@ -197,12 +201,12 @@ def decode(self, O): .. math:: - \mathtt{viterbi[i,j]} &= \max_{s'=1}^N \mathtt{viterbi[s',j-1]} \cdot - \mathtt{A[s',i]} \cdot \mathtt{B[i,o_j]} \\ - - &= \max_{s'=1}^N - P(o_1,\ldots,o_j,q_1,\ldots,q_{j-1},q_j=i \mid A,B,\pi) - P(q_j=i \mid q_{j-1}=s') P(o_j \mid q_j=i) + \mathtt{viterbi[i,j]} &= + \max_{s'=1}^N \mathtt{viterbi[s',j-1]} \cdot + \mathtt{A[s',i]} \cdot \mathtt{B[i,o_j]} \\ + &= \max_{s'=1}^N + P(o_1,\ldots, o_j, q_1, \ldots, q_{j-1}, q_j=i \mid A, B, \pi) + P(q_j=i \mid q_{j-1}=s') P(o_j \mid q_j=i) In words, ``viterbi[i,j]`` is the weighted sum of the values computed on the previous timestep. The weight on each value is the product of @@ -235,14 +239,14 @@ def decode(self, O): eps = self.eps if O.ndim == 1: - O = O.reshape(1, -1) + O = O.reshape(1, -1) # noqa: E741 # number of observations in each sequence T = O.shape[1] # number of training sequences - I = O.shape[0] - if I != 1: + I = O.shape[0] # noqa: E741 + if I != 1: # noqa: E741 raise ValueError("Can only decode a single sequence (O.shape[0] must be 1)") # initialize the viterbi and back_pointer matrices @@ -280,7 +284,7 @@ def decode(self, O): return best_path, best_path_log_prob def _forward(self, Obs): - """ + r""" Computes the forward probability trellis for an HMM parameterized by :math:`(A, B, \pi)`. @@ -291,16 +295,22 @@ def _forward(self, Obs): under the HMM of being in latent state `i` after seeing the first `j` observations: - .. math:: \mathtt{forward[i,j]} = P(o_1,\ldots,o_j,q_j=i|A,B,\pi) + .. math:: + + \mathtt{forward[i,j]} = + P(o_1, \ldots, o_j, q_j=i \mid A, B, \pi) Here :math:`q_j = i` indicates that the hidden state at time `j` is of type `i`. The DP step is:: - forward[i,j] = sum_{s'=1}^N forward[s',j-1] * A[s',i] * B[i,o_j] - = sum_{s'=1}^N P(o_1,\ldots,o_{j-1},q_{j-1}=s'|A,B,pi) * - P(q_j=i|q_{j-1}=s') * P(o_j|q_j=i) + .. math:: + + forward[i,j] &= + \sum_{s'=1}^N forward[s',j-1] \times A[s',i] \times B[i,o_j] \\ + &= \sum_{s'=1}^N P(o_1, \ldots, o_{j-1}, q_{j-1}=s' \mid A, B, \pi) + \times P(q_j=i \mid q_{j-1}=s') \times P(o_j \mid q_j=i) In words, ``forward[i,j]`` is the weighted sum of the values computed on the previous timestep. The weight on each previous state value is @@ -336,12 +346,12 @@ def _forward(self, Obs): + np.log(self.A[s_, s] + eps) + np.log(self.B[s, ot] + eps) for s_ in range(self.N) - ] + ] # noqa: C812 ) return forward def _backward(self, Obs): - """ + r""" Compute the backward probability trellis for an HMM parameterized by :math:`(A, B, \pi)`. @@ -352,15 +362,18 @@ def _backward(self, Obs): of seeing the observations from time `j+1` onward given that the HMM is in state `i` at time `j` - .. math:: \mathtt{backward[i,j]} = P(o_{j+1},o_{j+2},...,o_T|q_j=i,A,B,\pi) + .. math:: + + \mathtt{backward[i,j]} = P(o_{j+1},o_{j+2},\ldots,o_T \mid q_j=i,A,B,\pi) Here :math:`q_j = i` indicates that the hidden state at time `j` is of type `i`. The DP step is:: - backward[i,j] = sum_{s'=1}^N backward[s',j+1] * A[i, s'] * B[s',o_{j+1}] - = sum_{s'=1}^N P(o_{j+1},o_{j+2},...,o_T|q_j=i,A,B,pi) * - P(q_{j+1}=s'|q_{j}=i) * P(o_{j+1}|q_{j+1}=s') + backward[i,j] &= + \sum_{s'=1}^N backward[s',j+1] \times A[i, s'] \times B[s',o_{j+1}] \\ + &= \sum_{s'=1}^N P(o_{j+1}, o_{j+2}, \ldots, o_T \mid q_j=i, A, B, pi) + \times P(q_{j+1}=s' \mid q_{j}=i) \times P(o_{j+1} \mid q_{j+1}=s') In words, ``backward[i,j]`` is the weighted sum of the values computed on the following timestep. The weight on each state value from the @@ -396,12 +409,18 @@ def _backward(self, Obs): + np.log(self.B[s_, ot1] + eps) + backward[s_, t + 1] for s_ in range(self.N) - ] + ] # noqa: C812 ) return backward def fit( - self, O, latent_state_types, observation_types, pi=None, tol=1e-5, verbose=False + self, + O, + latent_state_types, + observation_types, + pi=None, + tol=1e-5, + verbose=False, ): """ Given an observation sequence `O` and the set of possible latent states, @@ -446,10 +465,10 @@ def fit( The estimated prior probabilities of each latent state. """ if O.ndim == 1: - O = O.reshape(1, -1) + O = O.reshape(1, -1) # noqa: E741 # observations - self.O = O + self.O = O # noqa: E741 # number of training examples (I) and their lengths (T) self.I, self.T = self.O.shape @@ -492,7 +511,7 @@ def fit( return self.A, self.B, self.pi def _Estep(self): - """ + r""" Run a single E-step update for the Baum-Welch/Forward-Backward algorithm. This step estimates ``xi`` and ``gamma``, the excepted state-state transition counts and the expected state-occupancy counts, @@ -502,17 +521,22 @@ def _Estep(self): and state `j` at time `k+1` given the observed sequence `O` and the current estimates for transition (`A`) and emission (`B`) matrices:: - xi[i,j,k] = P(q_k=i,q_{k+1}=j|O,A,B,pi) - = P(q_k=i,q_{k+1}=j,O|A,B,pi) / P(O|A,B,pi) - = [ - P(o_1,o_2,...,o_k,q_k=i|A,B,pi) * - P(q_{k+1}=j|q_k=i) * P(o_{k+1}|q_{k+1}=j) * - P(o_{k+2},o_{k+3},...,o_T|q_{k+1}=j,A,B,pi) - ] / P(O|A,B,pi) - = [ - fwd[j, k] * self.A[j, i] * - self.B[i, o_{k+1}] * bwd[i, k + 1] - ] / fwd[:, T].sum() + .. math:: + + xi[i,j,k] &= P(q_k=i,q_{k+1}=j \mid O,A,B,pi) \\ + &= \frac{ + P(q_k=i,q_{k+1}=j,O \mid A,B,pi) + }{P(O \mid A,B,pi)} \\ + &= \frac{ + P(o_1,o_2,\ldots,o_k,q_k=i \mid A,B,pi) \times + P(q_{k+1}=j \mid q_k=i) \times + P(o_{k+1} \mid q_{k+1}=j) \times + P(o_{k+2},o_{k+3},\ldots,o_T \mid q_{k+1}=j,A,B,pi) + }{P(O \mid A,B,pi)} \\ + &= \frac{ + \mathtt{fwd[j, k] * self.A[j, i] * + self.B[i, o_{k+1}] * bwd[i, k + 1]} + }{\mathtt{fwd[:, T].sum()}} The expected number of transitions from state `i` to state `j` across the entire sequence is then the sum over all timesteps: ``xi[i,j,:].sum()``. @@ -614,12 +638,12 @@ def _Mstep(self, gamma, xi, phi): for si in range(self.N): for vk in range(self.V): B[si, vk] = logsumexp(count_gamma[:, si, vk]) - logsumexp( - count_gamma[:, si, :] + count_gamma[:, si, :] # noqa: C812 ) for sj in range(self.N): A[si, sj] = logsumexp(count_xi[:, si, sj]) - logsumexp( - count_xi[:, si, :] + count_xi[:, si, :] # noqa: C812 ) np.testing.assert_almost_equal(np.exp(A[si, :]).sum(), 1) diff --git a/numpy_ml/neural_nets/__init__.py b/numpy_ml/neural_nets/__init__.py index a7bcf9e..d8e51ea 100644 --- a/numpy_ml/neural_nets/__init__.py +++ b/numpy_ml/neural_nets/__init__.py @@ -1,3 +1,4 @@ +"""A module of basic building blcoks for constructing neural networks""" from . import utils from . import losses from . import activations @@ -8,4 +9,3 @@ from . import initializers from . import modules from . import models -from . import tests diff --git a/numpy_ml/neural_nets/tests/__init__.py b/numpy_ml/neural_nets/tests/__init__.py deleted file mode 100644 index 73f8158..0000000 --- a/numpy_ml/neural_nets/tests/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -A module of tests for many of the components in the neural_nets package. - -Note that many of the tests in this module rely on external packages like -PyTorch and Tensorflow for gold-standard implementations. -""" - -from .tests import * diff --git a/numpy_ml/ngram/ngram.py b/numpy_ml/ngram/ngram.py index ada15f1..29dd336 100644 --- a/numpy_ml/ngram/ngram.py +++ b/numpy_ml/ngram/ngram.py @@ -1,3 +1,4 @@ +"""A module for different N-gram smoothing models""" import textwrap from abc import ABC, abstractmethod from collections import Counter @@ -5,11 +6,11 @@ import numpy as np from ..linear_models.lm import LinearRegression -from ..preprocessing.nlp import tokenize_words, ngrams +from ..preprocessing.nlp import tokenize_words, ngrams, strip_punctuation class NGramBase(ABC): - def __init__(self, N, unk=True, filter_stopwords=True): + def __init__(self, N, unk=True, filter_stopwords=True, filter_punctuation=True): """ A simple word-level N-gram language model. @@ -23,11 +24,13 @@ def __init__(self, N, unk=True, filter_stopwords=True): self.N = N self.unk = unk self.filter_stopwords = filter_stopwords + self.filter_punctuation = filter_punctuation self.hyperparameters = { "N": N, "unk": unk, "filter_stopwords": filter_stopwords, + "filter_punctuation": filter_punctuation, } super().__init__() @@ -57,20 +60,19 @@ def train(self, corpus_fp, vocab=None, encoding=None): return self._train(corpus_fp, vocab=vocab, encoding=encoding) def _train(self, corpus_fp, vocab=None, encoding=None): - """ - Actual N-gram training logic - """ + """Actual N-gram training logic""" H = self.hyperparameters grams = {N: [] for N in range(1, self.N + 1)} counts = {N: Counter() for N in range(1, self.N + 1)} - filter_stop = H["filter_stopwords"] + filter_stop, filter_punc = H["filter_stopwords"], H["filter_punctuation"] _n_words = 0 - tokens = set([""]) + tokens = {""} bol, eol = [""], [""] with open(corpus_fp, "r", encoding=encoding) as text: for line in text: + line = strip_punctuation(line) if filter_punc else line words = tokenize_words(line, filter_stopwords=filter_stop) if vocab is not None: @@ -174,7 +176,7 @@ def generate(self, N, seed_words=[""], n_sentences=5): return sentences def perplexity(self, words, N): - """ + r""" Calculate the model perplexity on a sequence of words. Notes @@ -183,13 +185,13 @@ def perplexity(self, words, N): .. math:: - PP(W) = \\left( \\frac{1}{p(W)} \\right)^{1 / n} + PP(W) = \left( \frac{1}{p(W)} \right)^{1 / n} or simply .. math:: - PP(W) &= \exp(-\log p(W) / n) \\\\ + PP(W) &= \exp(-\log p(W) / n) \\ &= \exp(H(W)) where :math:`W = [w_1, \ldots, w_k]` is a sequence of words, `H(w)` is @@ -216,7 +218,7 @@ def perplexity(self, words, N): return np.exp(self.cross_entropy(words, N)) def cross_entropy(self, words, N): - """ + r""" Calculate the model cross-entropy on a sequence of words against the empirical distribution of words in a sample. @@ -226,7 +228,7 @@ def cross_entropy(self, words, N): .. math:: - H(W) = -\\frac{\log p(W)}{n} + H(W) = -\frac{\log p(W)}{n} where :math:`W = [w_1, \ldots, w_k]` is a sequence of words, and `n` is the number of `N`-grams in `W`. @@ -251,7 +253,10 @@ def cross_entropy(self, words, N): return -(1 / n_ngrams) * self.log_prob(words, N) def _log_prob(self, words, N): - """Calculate the log probability of a sequence of words under the `N`-gram model""" + """ + Calculate the log probability of a sequence of words under the + `N`-gram model + """ assert N in self.counts, "You do not have counts for {}-grams".format(N) if N > len(words): @@ -293,10 +298,15 @@ def _num_grams_with_count(self, C, N): @abstractmethod def log_prob(self, words, N): + """ + Compute the log probability of a sequence of words under the + unsmoothed, maximum-likelihood `N`-gram language model. + """ raise NotImplementedError @abstractmethod def _log_ngram_prob(self, ngram): + """Return the unsmoothed log probability of the ngram""" raise NotImplementedError @@ -319,6 +329,7 @@ def __init__(self, N, unk=True, filter_stopwords=True, filter_punctuation=True): Whether to remove punctuation before training. Default is True. """ super().__init__(N, unk, filter_stopwords, filter_punctuation) + self.hyperparameters["id"] = "MLENGram" def log_prob(self, words, N): @@ -352,7 +363,7 @@ def _log_ngram_prob(self, ngram): class AdditiveNGram(NGramBase): def __init__( - self, N, K=1, unk=True, filter_stopwords=True, filter_punctuation=True + self, N, K=1, unk=True, filter_stopwords=True, filter_punctuation=True, ): """ An N-Gram model with smoothed probabilities calculated via additive / @@ -384,11 +395,12 @@ def __init__( Whether to remove punctuation before training. Default is True. """ super().__init__(N, unk, filter_stopwords, filter_punctuation) + self.hyperparameters["id"] = "AdditiveNGram" self.hyperparameters["K"] = K def log_prob(self, words, N): - """ + r""" Compute the smoothed log probability of a sequence of words under the `N`-gram language model with additive smoothing. @@ -398,15 +410,15 @@ def log_prob(self, words, N): .. math:: - P(w_i \mid w_{i-1}) = \\frac{A + K}{B + KV} + P(w_i \mid w_{i-1}) = \frac{A + K}{B + KV} where .. math:: - A &= \\text{Count}(w_{i-1}, w_i) \\\\ - B &= \sum_j \\text{Count}(w_{i-1}, w_j) \\\\ - V &= |\{ w_j \ : \ \\text{Count}(w_{i-1}, w_j) > 0 \}| + A &= \text{Count}(w_{i-1}, w_i) \\ + B &= \sum_j \text{Count}(w_{i-1}, w_j) \\ + V &= |\{ w_j \ : \ \text{Count}(w_{i-1}, w_j) > 0 \}| This is equivalent to pretending we've seen every possible `N`-gram sequence at least `K` times. @@ -446,7 +458,7 @@ def _log_ngram_prob(self, ngram): class GoodTuringNGram(NGramBase): def __init__( - self, N, conf=1.96, unk=True, filter_stopwords=True, filter_punctuation=True + self, N, conf=1.96, unk=True, filter_stopwords=True, filter_punctuation=True, ): """ An N-Gram model with smoothed probabilities calculated with the simple @@ -471,6 +483,7 @@ def __init__( Whether to remove punctuation before training. Default is True. """ super().__init__(N, unk, filter_stopwords, filter_punctuation) + self.hyperparameters["id"] = "GoodTuringNGram" self.hyperparameters["conf"] = conf @@ -497,7 +510,7 @@ def train(self, corpus_fp, vocab=None, encoding=None): self._calc_smoothed_counts() def log_prob(self, words, N): - """ + r""" Compute the smoothed log probability of a sequence of words under the `N`-gram language model with Good-Turing smoothing. @@ -507,21 +520,22 @@ def log_prob(self, words, N): .. math:: - P(w_i \mid w_{i-1}) = \\frac{C^*}{\\text{Count}(w_{i-1})} + P(w_i \mid w_{i-1}) = \frac{C^*}{\text{Count}(w_{i-1})} where :math:`C^*` is the Good-Turing smoothed estimate of the bigram count: .. math:: - C^* = \\frac{(c + 1) \\text{NumCounts}(c + 1, 2)}{\\text{NumCounts}(c, 2)} + C^* = \frac{(c + 1) \text{NumCounts}(c + 1, 2)}{\text{NumCounts}(c, 2)} where .. math:: - c &= \\text{Count}(w_{i-1}, w_i) \\\\ - \\text{NumCounts}(r, k) &= |\{ k\\text{-gram} : \\text{Count}(k\\text{-gram}) = r \}| + c &= \text{Count}(w_{i-1}, w_i) \\ + \text{NumCounts}(r, k) &= + |\{ k\text{-gram} : \text{Count}(k\text{-gram}) = r \}| In words, the probability of an `N`-gram that occurs `r` times in the corpus is estimated by dividing up the probability mass occupied by @@ -532,7 +546,7 @@ def log_prob(self, words, N): .. math:: - \log \\text{NumCounts}(r) = b + a \log r + \log \text{NumCounts}(r) = b + a \log r Under the Good-Turing estimator, the total probability assigned to unseen `N`-grams is equal to the relative occurrence of `N`-grams that diff --git a/numpy_ml/nonparametric/knn.py b/numpy_ml/nonparametric/knn.py index bf29089..8825229 100644 --- a/numpy_ml/nonparametric/knn.py +++ b/numpy_ml/nonparametric/knn.py @@ -82,7 +82,10 @@ def predict(self, X): if H["classifier"]: if H["weights"] == "uniform": - pred = Counter(targets).most_common(1)[0][0] + # for consistency with sklearn / scipy.stats.mode, return + # the smallest class ID in the event of a tie + counts = Counter(targets).most_common() + pred, _ = sorted(counts, key=lambda x: (-x[1], x[0]))[0] elif H["weights"] == "distance": best_score = -np.inf for label in set(targets): diff --git a/numpy_ml/bandits/plots.py b/numpy_ml/plots/bandit_plots.py similarity index 94% rename from numpy_ml/bandits/plots.py rename to numpy_ml/plots/bandit_plots.py index 9dd99fc..e39a30f 100644 --- a/numpy_ml/bandits/plots.py +++ b/numpy_ml/plots/bandit_plots.py @@ -4,15 +4,20 @@ import numpy as np -from .bandits import ( +from numpy_ml.bandits import ( MultinomialBandit, BernoulliBandit, ShortestPathBandit, ContextualLinearBandit, ) -from .trainer import BanditTrainer -from .policies import EpsilonGreedy, UCB1, ThompsonSamplingBetaBinomial, LinUCB -from ..utils.graphs import random_DAG, DiGraph, Edge +from numpy_ml.bandits.trainer import BanditTrainer +from numpy_ml.bandits.policies import ( + EpsilonGreedy, + UCB1, + ThompsonSamplingBetaBinomial, + LinUCB, +) +from numpy_ml.utils.graphs import random_DAG, DiGraph, Edge def random_multinomial_mab(n_arms=10, n_choices_per_arm=5, reward_range=[0, 1]): diff --git a/numpy_ml/gmm/plots.py b/numpy_ml/plots/gmm_plots.py similarity index 98% rename from numpy_ml/gmm/plots.py rename to numpy_ml/plots/gmm_plots.py index 00ed952..56f9232 100644 --- a/numpy_ml/gmm/plots.py +++ b/numpy_ml/plots/gmm_plots.py @@ -1,10 +1,9 @@ +# flake8: noqa import numpy as np from sklearn.datasets.samples_generator import make_blobs from scipy.stats import multivariate_normal -import matplotlib -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import seaborn as sns @@ -13,7 +12,7 @@ sns.set_style("white") sns.set_context("paper", font_scale=1) -from .gmm import GMM +from numpy_ml.gmm import GMM from matplotlib.colors import ListedColormap diff --git a/numpy_ml/hmm/tests.py b/numpy_ml/plots/hmm_plots.py similarity index 98% rename from numpy_ml/hmm/tests.py rename to numpy_ml/plots/hmm_plots.py index 35fce5a..7330c55 100644 --- a/numpy_ml/hmm/tests.py +++ b/numpy_ml/plots/hmm_plots.py @@ -1,7 +1,5 @@ +# flake8: noqa import numpy as np -import matplotlib - -matplotlib.use("TkAgg") from matplotlib import pyplot as plt import seaborn as sns @@ -12,7 +10,7 @@ sns.set_context("notebook", font_scale=0.8) from hmmlearn.hmm import MultinomialHMM as MHMM -from .hmm import MultinomialHMM +from numpy_ml.hmm import MultinomialHMM def generate_training_data(params, n_steps=500, n_examples=15): diff --git a/numpy_ml/lda/tests.py b/numpy_ml/plots/lda_plots.py similarity index 97% rename from numpy_ml/lda/tests.py rename to numpy_ml/plots/lda_plots.py index 1ff7b2f..18b584a 100644 --- a/numpy_ml/lda/tests.py +++ b/numpy_ml/plots/lda_plots.py @@ -1,8 +1,5 @@ +# flake8: noqa import numpy as np - -import matplotlib - -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import seaborn as sns @@ -13,7 +10,7 @@ np.random.seed(12345) -from .lda import LDA +from numpy_ml.lda import LDA def generate_corpus(): diff --git a/numpy_ml/linear_models/plots.py b/numpy_ml/plots/lm_plots.py similarity index 99% rename from numpy_ml/linear_models/plots.py rename to numpy_ml/plots/lm_plots.py index 4a91d48..6f9d455 100644 --- a/numpy_ml/linear_models/plots.py +++ b/numpy_ml/plots/lm_plots.py @@ -1,3 +1,4 @@ +# flake8: noqa import numpy as np from sklearn.model_selection import train_test_split @@ -6,9 +7,6 @@ from sklearn.datasets import make_regression from sklearn.metrics import zero_one_loss -import matplotlib - -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import seaborn as sns @@ -19,7 +17,7 @@ sns.set_context("paper", font_scale=0.5) -from .lm import ( +from numpy_ml.linear_models import ( RidgeRegression, LinearRegression, BayesianLinearRegressionKnownVariance, diff --git a/numpy_ml/ngram/plots.py b/numpy_ml/plots/ngram_plots.py similarity index 97% rename from numpy_ml/ngram/plots.py rename to numpy_ml/plots/ngram_plots.py index 0123e5e..27b0c94 100644 --- a/numpy_ml/ngram/plots.py +++ b/numpy_ml/plots/ngram_plots.py @@ -1,3 +1,4 @@ +# flake8: noqa import numpy as np import matplotlib.pyplot as plt @@ -8,7 +9,7 @@ sns.set_style("white") sns.set_context("notebook", font_scale=1) -from .ngram import MLENGram, AdditiveNGram, GoodTuringNGram +from numpy_ml.ngram import MLENGram, AdditiveNGram, GoodTuringNGram def plot_count_models(GT, N): diff --git a/numpy_ml/neural_nets/activations/plots.py b/numpy_ml/plots/nn_activations_plots.py similarity index 95% rename from numpy_ml/neural_nets/activations/plots.py rename to numpy_ml/plots/nn_activations_plots.py index c293f8a..b8db1fb 100644 --- a/numpy_ml/neural_nets/activations/plots.py +++ b/numpy_ml/plots/nn_activations_plots.py @@ -1,7 +1,5 @@ +# flake8: noqa import numpy as np -import matplotlib - -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import seaborn as sns @@ -10,7 +8,7 @@ sns.set_style("white") sns.set_context("notebook", font_scale=0.7) -from .activations import ( +from numpy_ml.neural_nets.activations import ( Affine, ReLU, LeakyReLU, diff --git a/numpy_ml/neural_nets/schedulers/tests.py b/numpy_ml/plots/nn_schedulers_plots.py similarity index 98% rename from numpy_ml/neural_nets/schedulers/tests.py rename to numpy_ml/plots/nn_schedulers_plots.py index b013e81..e18149b 100644 --- a/numpy_ml/neural_nets/schedulers/tests.py +++ b/numpy_ml/plots/nn_schedulers_plots.py @@ -1,8 +1,7 @@ +# flake8: noqa + import time import numpy as np -import matplotlib - -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import seaborn as sns @@ -11,7 +10,7 @@ sns.set_style("white") sns.set_context("notebook", font_scale=0.7) -from .schedulers import ( +from numpy_ml.neural_nets.schedulers import ( ConstantScheduler, ExponentialScheduler, NoamScheduler, diff --git a/numpy_ml/nonparametric/plots.py b/numpy_ml/plots/nonparametric_plots.py similarity index 98% rename from numpy_ml/nonparametric/plots.py rename to numpy_ml/plots/nonparametric_plots.py index d6027ce..5671a1f 100644 --- a/numpy_ml/nonparametric/plots.py +++ b/numpy_ml/plots/nonparametric_plots.py @@ -1,3 +1,4 @@ +# flake8: noqa import numpy as np import matplotlib.pyplot as plt @@ -8,10 +9,8 @@ sns.set_style("white") sns.set_context("paper", font_scale=0.5) -from .gp import GPRegression -from ..linear_models.lm import LinearRegression -from .kernel_regression import KernelRegression -from .knn import KNN +from numpy_ml.nonparametric import GPRegression, KNN, KernelRegression +from numpy_ml.linear_models.lm import LinearRegression from sklearn.model_selection import train_test_split diff --git a/numpy_ml/rl_models/tests.py b/numpy_ml/plots/rl_plots.py similarity index 96% rename from numpy_ml/rl_models/tests.py rename to numpy_ml/plots/rl_plots.py index 9d73863..30bfbb1 100644 --- a/numpy_ml/rl_models/tests.py +++ b/numpy_ml/plots/rl_plots.py @@ -1,7 +1,8 @@ +# flake8: noqa import gym -from .trainer import Trainer -from .agents import ( +from numpy_ml.rl_models.trainer import Trainer +from numpy_ml.rl_models.agents import ( CrossEntropyAgent, MonteCarloAgent, TemporalDifferenceAgent, diff --git a/numpy_ml/plots/trees_plots.py b/numpy_ml/plots/trees_plots.py new file mode 100644 index 0000000..74de5f5 --- /dev/null +++ b/numpy_ml/plots/trees_plots.py @@ -0,0 +1,161 @@ +# flake8: noqa +import numpy as np + +from sklearn.metrics import accuracy_score, mean_squared_error +from sklearn.datasets import make_blobs, make_regression +from sklearn.model_selection import train_test_split + +import matplotlib.pyplot as plt + +# https://seaborn.pydata.org/generated/seaborn.set_context.html +# https://seaborn.pydata.org/generated/seaborn.set_style.html +import seaborn as sns + +sns.set_style("white") +sns.set_context("paper", font_scale=0.9) + +from numpy_ml.trees import GradientBoostedDecisionTree, DecisionTree, RandomForest + + +def plot(): + fig, axes = plt.subplots(4, 4) + fig.set_size_inches(10, 10) + for ax in axes.flatten(): + n_ex = 100 + n_trees = 50 + n_feats = np.random.randint(2, 100) + max_depth_d = np.random.randint(1, 100) + max_depth_r = np.random.randint(1, 10) + + classifier = np.random.choice([True, False]) + if classifier: + # create classification problem + n_classes = np.random.randint(2, 10) + X, Y = make_blobs(n_samples=n_ex, centers=n_classes, n_features=2) + X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3) + n_feats = min(n_feats, X.shape[1]) + + # initialize model + def loss(yp, y): + return accuracy_score(yp, y) + + # initialize model + criterion = np.random.choice(["entropy", "gini"]) + mine = RandomForest( + classifier=classifier, + n_feats=n_feats, + n_trees=n_trees, + criterion=criterion, + max_depth=max_depth_r, + ) + mine_d = DecisionTree( + criterion=criterion, max_depth=max_depth_d, classifier=classifier + ) + mine_g = GradientBoostedDecisionTree( + n_trees=n_trees, + max_depth=max_depth_d, + classifier=classifier, + learning_rate=1, + loss="crossentropy", + step_size="constant", + split_criterion=criterion, + ) + + else: + # create regeression problem + X, Y = make_regression(n_samples=n_ex, n_features=1) + X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3) + n_feats = min(n_feats, X.shape[1]) + + # initialize model + criterion = "mse" + loss = mean_squared_error + mine = RandomForest( + criterion=criterion, + n_feats=n_feats, + n_trees=n_trees, + max_depth=max_depth_r, + classifier=classifier, + ) + mine_d = DecisionTree( + criterion=criterion, max_depth=max_depth_d, classifier=classifier + ) + mine_g = GradientBoostedDecisionTree( + n_trees=n_trees, + max_depth=max_depth_d, + classifier=classifier, + learning_rate=1, + loss="mse", + step_size="adaptive", + split_criterion=criterion, + ) + + # fit 'em + mine.fit(X, Y) + mine_d.fit(X, Y) + mine_g.fit(X, Y) + + # get preds on test set + y_pred_mine_test = mine.predict(X_test) + y_pred_mine_test_d = mine_d.predict(X_test) + y_pred_mine_test_g = mine_g.predict(X_test) + + loss_mine_test = loss(y_pred_mine_test, Y_test) + loss_mine_test_d = loss(y_pred_mine_test_d, Y_test) + loss_mine_test_g = loss(y_pred_mine_test_g, Y_test) + + if classifier: + entries = [ + ("RF", loss_mine_test, y_pred_mine_test), + ("DT", loss_mine_test_d, y_pred_mine_test_d), + ("GB", loss_mine_test_g, y_pred_mine_test_g), + ] + (lbl, test_loss, preds) = entries[np.random.randint(3)] + ax.set_title("{} Accuracy: {:.2f}%".format(lbl, test_loss * 100)) + for i in np.unique(Y_test): + ax.scatter( + X_test[preds == i, 0].flatten(), + X_test[preds == i, 1].flatten(), + # s=0.5, + ) + else: + X_ax = np.linspace( + np.min(X_test.flatten()) - 1, np.max(X_test.flatten()) + 1, 100 + ).reshape(-1, 1) + y_pred_mine_test = mine.predict(X_ax) + y_pred_mine_test_d = mine_d.predict(X_ax) + y_pred_mine_test_g = mine_g.predict(X_ax) + + ax.scatter(X_test.flatten(), Y_test.flatten(), c="b", alpha=0.5) + # s=0.5) + ax.plot( + X_ax.flatten(), + y_pred_mine_test_g.flatten(), + # linewidth=0.5, + label="GB".format(n_trees, n_feats, max_depth_d), + color="red", + ) + ax.plot( + X_ax.flatten(), + y_pred_mine_test.flatten(), + # linewidth=0.5, + label="RF".format(n_trees, n_feats, max_depth_r), + color="cornflowerblue", + ) + ax.plot( + X_ax.flatten(), + y_pred_mine_test_d.flatten(), + # linewidth=0.5, + label="DT".format(max_depth_d), + color="yellowgreen", + ) + ax.set_title( + "GB: {:.1f} / RF: {:.1f} / DT: {:.1f} ".format( + loss_mine_test_g, loss_mine_test, loss_mine_test_d + ) + ) + ax.legend() + ax.xaxis.set_ticklabels([]) + ax.yaxis.set_ticklabels([]) + plt.savefig("plot.png", dpi=300) + plt.close("all") diff --git a/numpy_ml/preprocessing/nlp.py b/numpy_ml/preprocessing/nlp.py index 4533e36..68fc28e 100644 --- a/numpy_ml/preprocessing/nlp.py +++ b/numpy_ml/preprocessing/nlp.py @@ -1,3 +1,4 @@ +"""Common preprocessing utilities for working with text data""" import re import heapq import os.path as op @@ -9,328 +10,327 @@ # This list of English stop words is taken from the "Glasgow Information # Retrieval Group". The original list can be found at # http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words -_STOP_WORDS = set( - [ - "a", - "about", - "above", - "across", - "after", - "afterwards", - "again", - "against", - "all", - "almost", - "alone", - "along", - "already", - "also", - "although", - "always", - "am", - "among", - "amongst", - "amoungst", - "amount", - "an", - "and", - "another", - "any", - "anyhow", - "anyone", - "anything", - "anyway", - "anywhere", - "are", - "around", - "as", - "at", - "back", - "be", - "became", - "because", - "become", - "becomes", - "becoming", - "been", - "before", - "beforehand", - "behind", - "being", - "below", - "beside", - "besides", - "between", - "beyond", - "bill", - "both", - "bottom", - "but", - "by", - "call", - "can", - "cannot", - "cant", - "co", - "con", - "could", - "couldnt", - "cry", - "de", - "describe", - "detail", - "do", - "done", - "down", - "due", - "during", - "each", - "eg", - "eight", - "either", - "eleven", - "else", - "elsewhere", - "empty", - "enough", - "etc", - "even", - "ever", - "every", - "everyone", - "everything", - "everywhere", - "except", - "few", - "fifteen", - "fifty", - "fill", - "find", - "fire", - "first", - "five", - "for", - "former", - "formerly", - "forty", - "found", - "four", - "from", - "front", - "full", - "further", - "get", - "give", - "go", - "had", - "has", - "hasnt", - "have", - "he", - "hence", - "her", - "here", - "hereafter", - "hereby", - "herein", - "hereupon", - "hers", - "herself", - "him", - "himself", - "his", - "how", - "however", - "hundred", - "i", - "ie", - "if", - "in", - "inc", - "indeed", - "interest", - "into", - "is", - "it", - "its", - "itself", - "keep", - "last", - "latter", - "latterly", - "least", - "less", - "ltd", - "made", - "many", - "may", - "me", - "meanwhile", - "might", - "mill", - "mine", - "more", - "moreover", - "most", - "mostly", - "move", - "much", - "must", - "my", - "myself", - "name", - "namely", - "neither", - "never", - "nevertheless", - "next", - "nine", - "no", - "nobody", - "none", - "noone", - "nor", - "not", - "nothing", - "now", - "nowhere", - "of", - "off", - "often", - "on", - "once", - "one", - "only", - "onto", - "or", - "other", - "others", - "otherwise", - "our", - "ours", - "ourselves", - "out", - "over", - "own", - "part", - "per", - "perhaps", - "please", - "put", - "rather", - "re", - "same", - "see", - "seem", - "seemed", - "seeming", - "seems", - "serious", - "several", - "she", - "should", - "show", - "side", - "since", - "sincere", - "six", - "sixty", - "so", - "some", - "somehow", - "someone", - "something", - "sometime", - "sometimes", - "somewhere", - "still", - "such", - "system", - "take", - "ten", - "than", - "that", - "the", - "their", - "them", - "themselves", - "then", - "thence", - "there", - "thereafter", - "thereby", - "therefore", - "therein", - "thereupon", - "these", - "they", - "thick", - "thin", - "third", - "this", - "those", - "though", - "three", - "through", - "throughout", - "thru", - "thus", - "to", - "together", - "too", - "top", - "toward", - "towards", - "twelve", - "twenty", - "two", - "un", - "under", - "until", - "up", - "upon", - "us", - "very", - "via", - "was", - "we", - "well", - "were", - "what", - "whatever", - "when", - "whence", - "whenever", - "where", - "whereafter", - "whereas", - "whereby", - "wherein", - "whereupon", - "wherever", - "whether", - "which", - "while", - "whither", - "who", - "whoever", - "whole", - "whom", - "whose", - "why", - "will", - "with", - "within", - "without", - "would", - "yet", - "you", - "your", - "yours", - "yourself", - "yourselves", - ] -) +_STOP_WORDS = { + "a", + "about", + "above", + "across", + "after", + "afterwards", + "again", + "against", + "all", + "almost", + "alone", + "along", + "already", + "also", + "although", + "always", + "am", + "among", + "amongst", + "amoungst", + "amount", + "an", + "and", + "another", + "any", + "anyhow", + "anyone", + "anything", + "anyway", + "anywhere", + "are", + "around", + "as", + "at", + "back", + "be", + "became", + "because", + "become", + "becomes", + "becoming", + "been", + "before", + "beforehand", + "behind", + "being", + "below", + "beside", + "besides", + "between", + "beyond", + "bill", + "both", + "bottom", + "but", + "by", + "call", + "can", + "cannot", + "cant", + "co", + "con", + "could", + "couldnt", + "cry", + "de", + "describe", + "detail", + "do", + "done", + "down", + "due", + "during", + "each", + "eg", + "eight", + "either", + "eleven", + "else", + "elsewhere", + "empty", + "enough", + "etc", + "even", + "ever", + "every", + "everyone", + "everything", + "everywhere", + "except", + "few", + "fifteen", + "fifty", + "fill", + "find", + "fire", + "first", + "five", + "for", + "former", + "formerly", + "forty", + "found", + "four", + "from", + "front", + "full", + "further", + "get", + "give", + "go", + "had", + "has", + "hasnt", + "have", + "he", + "hence", + "her", + "here", + "hereafter", + "hereby", + "herein", + "hereupon", + "hers", + "herself", + "him", + "himself", + "his", + "how", + "however", + "hundred", + "i", + "ie", + "if", + "in", + "inc", + "indeed", + "interest", + "into", + "is", + "it", + "its", + "itself", + "keep", + "last", + "latter", + "latterly", + "least", + "less", + "ltd", + "made", + "many", + "may", + "me", + "meanwhile", + "might", + "mill", + "mine", + "more", + "moreover", + "most", + "mostly", + "move", + "much", + "must", + "my", + "myself", + "name", + "namely", + "neither", + "never", + "nevertheless", + "next", + "nine", + "no", + "nobody", + "none", + "noone", + "nor", + "not", + "nothing", + "now", + "nowhere", + "of", + "off", + "often", + "on", + "once", + "one", + "only", + "onto", + "or", + "other", + "others", + "otherwise", + "our", + "ours", + "ourselves", + "out", + "over", + "own", + "part", + "per", + "perhaps", + "please", + "put", + "rather", + "re", + "same", + "see", + "seem", + "seemed", + "seeming", + "seems", + "serious", + "several", + "she", + "should", + "show", + "side", + "since", + "sincere", + "six", + "sixty", + "so", + "some", + "somehow", + "someone", + "something", + "sometime", + "sometimes", + "somewhere", + "still", + "such", + "system", + "take", + "ten", + "than", + "that", + "the", + "their", + "them", + "themselves", + "then", + "thence", + "there", + "thereafter", + "thereby", + "therefore", + "therein", + "thereupon", + "these", + "they", + "thick", + "thin", + "third", + "this", + "those", + "though", + "three", + "through", + "throughout", + "thru", + "thus", + "to", + "together", + "too", + "top", + "toward", + "towards", + "twelve", + "twenty", + "two", + "un", + "under", + "until", + "up", + "upon", + "us", + "very", + "via", + "was", + "we", + "well", + "were", + "what", + "whatever", + "when", + "whence", + "whenever", + "where", + "whereafter", + "whereas", + "whereby", + "wherein", + "whereupon", + "wherever", + "whether", + "which", + "while", + "whither", + "who", + "whoever", + "whole", + "whom", + "whose", + "why", + "will", + "with", + "within", + "without", + "would", + "yet", + "you", + "your", + "yours", + "yourself", + "yourselves", +} + _PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" _WORD_REGEX = re.compile(r"(?u)\b\w\w+\b") # sklearn default @@ -386,21 +386,25 @@ def __init__(self, key, val): self.right = None def __gt__(self, other): + """Greater than""" if not isinstance(other, Node): return -1 return self.val > other.val def __ge__(self, other): + """Greater than or equal to""" if not isinstance(other, Node): return -1 return self.val >= other.val def __lt__(self, other): + """Less than""" if not isinstance(other, Node): return -1 return self.val < other.val def __le__(self, other): + """Less than or equal to""" if not isinstance(other, Node): return -1 return self.val <= other.val @@ -482,10 +486,12 @@ def inverse_transform(self, codes): @property def tokens(self): + """A list the unique tokens in `text`""" return list(self._item2code.keys()) @property def codes(self): + """A list with the Huffman code for each unique token in `text`""" return list(self._code2item.keys()) def _counter(self, text): @@ -552,6 +558,7 @@ def __init__(self, word): self.word = word def __repr__(self): + """A string representation of the token""" return "Token(word='{}', count={})".format(self.word, self.count) @@ -566,7 +573,7 @@ def __init__( input_type="filename", filter_stopwords=True, ): - """ + r""" An object for compiling and encoding the term-frequency inverse-document-frequency (TF-IDF) representation of the tokens in a text corpus. @@ -578,8 +585,8 @@ def __init__( corpus, :math:`D = \{d_1, \ldots, d_N\}`, we have: .. math:: - \\text{TF}(w, d) &= \\text{num. occurences of }`w`\\text{ in document }`d` \\\\ - \\text{IDF}(w, D) &= \log \\frac{|D|}{|\{ d \in D: t \in d \}|} + \text{TF}(w, d) &= \text{num. occurences of }w \text{ in document }d \\ + \text{IDF}(w, D) &= \log \frac{|D|}{|\{ d \in D: t \in d \}|} Parameters ---------- @@ -694,7 +701,7 @@ def fit(self, corpus_seq, encoding="utf-8-sig"): doc_count = {} idx2doc[d_ix] = doc if H["input_type"] == "files" else None token2idx, idx2token, tokens, doc_count = self._encode_document( - doc, token2idx, idx2token, tokens, doc_count, bol_ix, eol_ix + doc, token2idx, idx2token, tokens, doc_count, bol_ix, eol_ix, ) term_freq[d_ix] = doc_count @@ -720,11 +727,9 @@ def fit(self, corpus_seq, encoding="utf-8-sig"): self._calc_idf() def _encode_document( - self, doc, word2idx, idx2word, tokens, doc_count, bol_ix, eol_ix + self, doc, word2idx, idx2word, tokens, doc_count, bol_ix, eol_ix, ): - """ - Perform tokenization and compute token counts for a single document - """ + """Perform tokenization and compute token counts for a single document""" H = self.hyperparameters lowercase = H["lowercase"] filter_stop = H["filter_stopwords"] @@ -816,7 +821,7 @@ def _drop_low_freq_tokens(self): unk_idx = 0 word2idx = {"": 0, "": 1, "": 2} idx2word = {0: "", 1: "", 2: ""} - special = set(["", "", ""]) + special = {"", "", ""} for tt in self._tokens: if tt.word not in special: @@ -895,7 +900,7 @@ def _calc_idf(self): for word, w_ix in self.token2idx.items(): d_count = int(smooth_idf) d_count += np.sum([1 if w_ix in tf[d_ix] else 0 for d_ix in doc_idxs]) - inv_doc_freq[w_ix] = np.log(D / d_count) + 1 + inv_doc_freq[w_ix] = 1 if d_count == 0 else np.log(D / d_count) + 1 self.inv_doc_freq = inv_doc_freq def transform(self, ignore_special_chars=True): @@ -944,7 +949,7 @@ def transform(self, ignore_special_chars=True): class Vocabulary: def __init__( - self, lowercase=True, min_count=None, max_tokens=None, filter_stopwords=True + self, lowercase=True, min_count=None, max_tokens=None, filter_stopwords=True, ): """ An object for compiling and encoding the unique tokens in a text corpus. @@ -977,15 +982,22 @@ def __init__( } def __len__(self): + """Return the number of tokens in the vocabulary""" return len(self._tokens) def __iter__(self): + """Return an iterator over the tokens in the vocabulary""" return iter(self._tokens) def __contains__(self, word): + """Assert whether `word` is a token in the vocabulary""" return word in self.token2idx def __getitem__(self, key): + """ + Return the token (if key is an integer) or the index (if key is a string) + for the key in the vocabulary, if it exists. + """ if isinstance(key, str): return self._tokens[self.token2idx[key]] if isinstance(key, int): @@ -1014,7 +1026,7 @@ def words_with_count(self, k): """Return all tokens that occur `k` times in the corpus""" return [w for w, c in self.counts.items() if c == k] - def filter(self, words, unk=True): + def filter(self, words, unk=True): # noqa: A003 """ Filter or replace any word in `words` that does not occur in `Vocabulary` @@ -1201,7 +1213,7 @@ def _drop_low_freq_tokens(self): tokens = [unk_token, eol_token, bol_token] word2idx = {"": 0, "": 1, "": 2} idx2word = {0: "", 1: "", 2: ""} - special = set(["", "", ""]) + special = {"", "", ""} for tt in self._tokens: if tt.word not in special: diff --git a/numpy_ml/rl_models/rl_utils.py b/numpy_ml/rl_models/rl_utils.py index 2d80680..245bb8e 100644 --- a/numpy_ml/rl_models/rl_utils.py +++ b/numpy_ml/rl_models/rl_utils.py @@ -1,11 +1,28 @@ +"""Utilities for training and evaluating RL models on OpenAI gym environments""" +import warnings from itertools import product from collections import defaultdict import numpy as np -import gym - -from .tiles.tiles3 import tiles, IHT +from numpy_ml.utils.testing import DependencyWarning +from numpy_ml.rl_models.tiles.tiles3 import tiles, IHT + +NO_PD = False +try: + import pandas as pd +except ModuleNotFoundError: + NO_PD = True + +try: + import gym +except ModuleNotFoundError: + fstr = ( + "Agents in `numpy_ml.rl_models` use the OpenAI gym for training. " + "To install the gym environments, run `pip install gym`. For more" + " information, see https://github.com/openai/gym." + ) + warnings.warn(fstr, DependencyWarning) class EnvModel(object): @@ -29,23 +46,24 @@ def __init__(self): self._model = defaultdict(lambda: defaultdict(lambda: 0)) def __setitem__(self, key, value): + """Set self[key] to value""" s, a, r, s_ = key self._model[(s, a)][(r, s_)] = value def __getitem__(self, key): + """Return the value associated with key""" s, a, r, s_ = key return self._model[(s, a)][(r, s_)] def __contains__(self, key): + """True if EnvModel contains `key`, else False""" s, a, r, s_ = key p1 = (s, a) in self.state_action_pairs() p2 = (r, s_) in self.reward_outcome_pairs() return p1 and p2 def state_action_pairs(self): - """ - Return all (state, action) pairs in the environment model - """ + """Return all (state, action) pairs in the environment model""" return list(self._model.keys()) def reward_outcome_pairs(self, s, a): @@ -166,7 +184,7 @@ def tile_state_space( scale = 1.0 / obs_range # scale (state-)observation vector - scale_obs = lambda obs: obs * scale + scale_obs = lambda obs: obs * scale # noqa: E731 n_tiles = np.prod(grid_size) * n_tilings n_states = np.prod([n_tiles - i for i in range(n_tilings)]) @@ -180,16 +198,12 @@ def encode_obs_as_tile(obs): def get_gym_environs(): - """ List all valid OpenAI ``gym`` environment ids. """ + """List all valid OpenAI ``gym`` environment ids""" return [e.id for e in gym.envs.registry.all()] def get_gym_stats(): - """ Return a pandas DataFrame of the environment IDs. """ - try: - import pandas as pd - except: - raise ImportError("Cannot import `pandas`; unable to run `get_gym_stats`") + """Return a pandas DataFrame of the environment IDs.""" df = [] for e in gym.envs.registry.all(): print(e.id) @@ -211,7 +225,7 @@ def get_gym_stats(): "tuple_actions", "tuple_observations", ] - return pd.DataFrame(df)[cols] + return df if NO_PD else pd.DataFrame(df)[cols] def is_tuple(env): @@ -305,13 +319,13 @@ def is_continuous(env, tuple_action, tuple_obs): Continuous = gym.spaces.box.Box if tuple_obs: spaces = env.observation_space.spaces - cont_obs = all([isinstance(s, Continuous) for s in spaces]) + cont_obs = all(isinstance(s, Continuous) for s in spaces) else: cont_obs = isinstance(env.observation_space, Continuous) if tuple_action: spaces = env.action_space.spaces - cont_action = all([isinstance(s, Continuous) for s in spaces]) + cont_action = all(isinstance(s, Continuous) for s in spaces) else: cont_action = isinstance(env.action_space, Continuous) return cont_action, cont_obs @@ -432,7 +446,7 @@ def env_stats(env): cont_action, cont_obs = is_continuous(env, tuple_action, tuple_obs) n_actions_per_dim, action_ids, action_dim = action_stats( - env, md_action, cont_action + env, md_action, cont_action, ) n_obs_per_dim, obs_ids, obs_dim = obs_stats(env, md_obs, cont_obs) diff --git a/numpy_ml/tests/__init__.py b/numpy_ml/tests/__init__.py new file mode 100644 index 0000000..20ff959 --- /dev/null +++ b/numpy_ml/tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for various numpy-ml modules""" diff --git a/numpy_ml/neural_nets/tests/torch_models.py b/numpy_ml/tests/nn_torch_models.py similarity index 97% rename from numpy_ml/neural_nets/tests/torch_models.py rename to numpy_ml/tests/nn_torch_models.py index 8de9212..a5ae3dc 100644 --- a/numpy_ml/neural_nets/tests/torch_models.py +++ b/numpy_ml/tests/nn_torch_models.py @@ -1,3 +1,5 @@ +# flake8: noqa + import torch import torch.nn as nn import torch.nn.functional as F @@ -30,7 +32,7 @@ def get_grad(z): def torch_xe_grad(y, z): z = torch.autograd.Variable(torch.FloatTensor(z), requires_grad=True) y = torch.LongTensor(y.argmax(axis=1)) - loss = F.cross_entropy(z, y, size_average=False).sum() + loss = F.cross_entropy(z, y, reduction="sum") loss.backward() grad = z.grad.numpy() return grad @@ -40,7 +42,7 @@ def torch_mse_grad(y, z, act_fn): y = torch.FloatTensor(y) z = torch.autograd.Variable(torch.FloatTensor(z), requires_grad=True) y_pred = act_fn(z) - loss = F.mse_loss(y_pred, y, size_average=False).sum() + loss = F.mse_loss(y_pred, y, reduction="sum") # size_average=False).sum() loss.backward() grad = z.grad.numpy() return grad @@ -57,7 +59,7 @@ def extract_grads(self, X, X_recon, t_mean, t_log_var): t_mean = torchify(t_mean) t_log_var = torchify(t_log_var) - BCE = torch.sum(F.binary_cross_entropy(X_recon, X, reduce=False), dim=1) + BCE = torch.sum(F.binary_cross_entropy(X_recon, X, reduction="none"), dim=1) # see Appendix B from VAE paper: # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 @@ -622,8 +624,8 @@ def forward(self, X_main, X_skip): self.conv_dilation_out = self.conv_dilation(self.X_main) self.conv_dilation_out.retain_grad() - self.tanh_out = F.tanh(self.conv_dilation_out) - self.sigm_out = F.sigmoid(self.conv_dilation_out) + self.tanh_out = torch.tanh(self.conv_dilation_out) + self.sigm_out = torch.sigmoid(self.conv_dilation_out) self.tanh_out.retain_grad() self.sigm_out.retain_grad() @@ -1869,7 +1871,7 @@ def LinearLayer(name, n_in, n_out, inputs, w_initialization): def Generator(n_samples, X_real, params=None): n_feats = 2 W1 = W2 = W3 = W4 = "he" - noise = tf.random_normal([n_samples, 2]) + noise = tf.random.normal([n_samples, 2]) if params is not None: noise = tf.convert_to_tensor(params["noise"], dtype="float32") W1 = params["generator"]["FC1"]["W"] @@ -1940,6 +1942,8 @@ def Discriminator(inputs, params=None): def WGAN_GP_tf(X, lambda_, params, batch_size): + tf.compat.v1.disable_eager_execution() + batch_size = X.shape[0] # get alpha value @@ -1947,7 +1951,7 @@ def WGAN_GP_tf(X, lambda_, params, batch_size): c_updates_per_epoch = params["c_updates_per_epoch"] alpha = tf.convert_to_tensor(params["alpha"], dtype="float32") - X_real = tf.placeholder(tf.float32, shape=[None, params["n_in"]]) + X_real = tf.compat.v1.placeholder(tf.float32, shape=[None, params["n_in"]]) X_fake, G_out_X_fake, G_weights = Generator(batch_size, X_real, params) Y_real, C_out_Y_real, C_Y_real_weights = Discriminator(X_real, params) @@ -1966,7 +1970,7 @@ def WGAN_GP_tf(X, lambda_, params, batch_size): gradInterp = tf.gradients(Y_interp, [X_interp])[0] norm_gradInterp = tf.sqrt( - tf.reduce_sum(tf.square(gradInterp), reduction_indices=[1]) + tf.compat.v1.reduce_sum(tf.square(gradInterp), reduction_indices=[1]) ) gradient_penalty = tf.reduce_mean((norm_gradInterp - 1) ** 2) C_loss += lambda_ * gradient_penalty @@ -1986,8 +1990,8 @@ def WGAN_GP_tf(X, lambda_, params, batch_size): dC_gradInterp = tf.gradients(C_loss, [gradInterp])[0] dG_Y_fake = tf.gradients(G_loss, [Y_fake])[0] - with tf.Session() as session: - session.run(tf.global_variables_initializer()) + with tf.compat.v1.Session() as session: + session.run(tf.compat.v1.global_variables_initializer()) for iteration in range(n_steps): # Train critic @@ -2189,13 +2193,25 @@ def TFNCELoss(X, target_word, L): from tensorflow.python.ops.nn_impl import _compute_sampled_logits from tensorflow.python.ops.nn_impl import sigmoid_cross_entropy_with_logits - in_embed = tf.placeholder(tf.float32, shape=X.shape) - in_bias = tf.placeholder(tf.float32, shape=L.parameters["b"].flatten().shape) - in_weights = tf.placeholder(tf.float32, shape=L.parameters["W"].shape) - in_target_word = tf.placeholder(tf.int64) - in_neg_samples = tf.placeholder(tf.int32) - in_target_prob = tf.placeholder(tf.float32) - in_neg_samp_prob = tf.placeholder(tf.float32) + tf.compat.v1.disable_eager_execution() + + in_embed = tf.compat.v1.placeholder(tf.float32, shape=X.shape) + in_bias = tf.compat.v1.placeholder( + tf.float32, shape=L.parameters["b"].flatten().shape + ) + in_weights = tf.compat.v1.placeholder(tf.float32, shape=L.parameters["W"].shape) + in_target_word = tf.compat.v1.placeholder(tf.int64) + in_neg_samples = tf.compat.v1.placeholder(tf.int32) + in_target_prob = tf.compat.v1.placeholder(tf.float32) + in_neg_samp_prob = tf.compat.v1.placeholder(tf.float32) + + # in_embed = tf.keras.Input(dtype=tf.float32, shape=X.shape) + # in_bias = tf.keras.Input(dtype=tf.float32, shape=L.parameters["b"].flatten().shape) + # in_weights = tf.keras.Input(dtype=tf.float32, shape=L.parameters["W"].shape) + # in_target_word = tf.keras.Input(dtype=tf.int64, shape=()) + # in_neg_samples = tf.keras.Input(dtype=tf.int32, shape=()) + # in_target_prob = tf.keras.Input(dtype=tf.float32, shape=()) + # in_neg_samp_prob = tf.keras.Input(dtype=tf.float32, shape=()) feed = { in_embed: X, @@ -2239,8 +2255,8 @@ def TFNCELoss(X, target_word, L): labels=sampled_labels, logits=sampled_logits ) - with tf.Session() as session: - session.run(tf.global_variables_initializer()) + with tf.compat.v1.Session() as session: + session.run(tf.compat.v1.global_variables_initializer()) ( _final_loss, _nce_unreduced, @@ -2263,7 +2279,7 @@ def TFNCELoss(X, target_word, L): ], feed_dict=feed, ) - tf.reset_default_graph() + tf.compat.v1.reset_default_graph() return { "final_loss": _final_loss, "nce_unreduced": _nce_unreduced, diff --git a/numpy_ml/ngram/tests.py b/numpy_ml/tests/test_ngram.py similarity index 92% rename from numpy_ml/ngram/tests.py rename to numpy_ml/tests/test_ngram.py index 09f03b4..0fd4252 100644 --- a/numpy_ml/ngram/tests.py +++ b/numpy_ml/tests/test_ngram.py @@ -1,8 +1,12 @@ +# flake8: noqa +import tempfile + import nltk import numpy as np from ..preprocessing.nlp import tokenize_words -from .ngram import AdditiveNGram, MLENGram +from ..ngram import AdditiveNGram, MLENGram +from ..utils.testing import random_paragraph class MLEGold: @@ -23,8 +27,6 @@ def __init__( "filter_punctuation": filter_punctuation, } - super().__init__() - def train(self, corpus_fp, vocab=None, encoding=None): N = self.N H = self.hyperparameters @@ -120,8 +122,6 @@ def __init__( "filter_punctuation": filter_punctuation, } - super().__init__() - def train(self, corpus_fp, vocab=None, encoding=None): N = self.N H = self.hyperparameters @@ -204,8 +204,10 @@ def test_mle(): gold = MLEGold(N, unk=True, filter_stopwords=False, filter_punctuation=False) mine = MLENGram(N, unk=True, filter_stopwords=False, filter_punctuation=False) - gold.train("russell.txt", encoding="utf-8-sig") - mine.train("russell.txt", encoding="utf-8-sig") + with tempfile.NamedTemporaryFile() as temp: + temp.write(bytes(" ".join(random_paragraph(1000)), encoding="utf-8-sig")) + gold.train(temp.name, encoding="utf-8-sig") + mine.train(temp.name, encoding="utf-8-sig") for k in mine.counts[N].keys(): if k[0] == k[1] and k[0] in ("", ""): @@ -232,8 +234,10 @@ def test_additive(): N, K, unk=True, filter_stopwords=False, filter_punctuation=False ) - gold.train("russell.txt", encoding="utf-8-sig") - mine.train("russell.txt", encoding="utf-8-sig") + with tempfile.NamedTemporaryFile() as temp: + temp.write(bytes(" ".join(random_paragraph(1000)), encoding="utf-8-sig")) + gold.train(temp.name, encoding="utf-8-sig") + mine.train(temp.name, encoding="utf-8-sig") for k in mine.counts[N].keys(): if k[0] == k[1] and k[0] in ("", ""): diff --git a/numpy_ml/neural_nets/tests/tests.py b/numpy_ml/tests/test_nn.py similarity index 91% rename from numpy_ml/neural_nets/tests/tests.py rename to numpy_ml/tests/test_nn.py index bdfe950..1f00680 100644 --- a/numpy_ml/neural_nets/tests/tests.py +++ b/numpy_ml/tests/test_nn.py @@ -14,14 +14,22 @@ import torch.nn as nn import torch.nn.functional as F -from ..utils import calc_pad_dims_2D, conv2D_naive, conv2D, pad2D, pad1D -from ...utils.testing import ( +import tensorflow.keras.datasets.mnist as mnist + +from numpy_ml.neural_nets.utils import ( + calc_pad_dims_2D, + conv2D_naive, + conv2D, + pad2D, + pad1D, +) +from numpy_ml.utils.testing import ( random_one_hot_matrix, random_stochastic_matrix, random_tensor, ) -from .torch_models import ( +from .nn_torch_models import ( TFNCELoss, WGAN_GP_tf, torch_xe_grad, @@ -70,171 +78,15 @@ def err_fmt(params, golds, ix, warn_str=""): return err_msg -####################################################################### -# Test Suite # -####################################################################### - - -def test_everything(N=50): - test_losses(N=N) - test_activations(N=N) - test_layers(N=N) - test_utils(N=N) - test_modules(N=N) - - -def test_losses(N=50): - print("Testing SquaredError loss") - time.sleep(1) - test_squared_error(N) - test_squared_error_grad(N) - - print("Testing CrossEntropy loss") - time.sleep(1) - test_cross_entropy(N) - test_cross_entropy_grad(N) - - print("Testing VAELoss") - time.sleep(1) - test_VAE_loss(N) - - print("Testing WGAN_GPLoss") - time.sleep(1) - test_WGAN_GP_loss(N) - - print("Testing NCELoss") - time.sleep(1) - test_NCELoss(N) - - -def test_activations(N=50): - print("Testing Sigmoid activation") - time.sleep(1) - test_sigmoid_activation(N) - test_sigmoid_grad(N) - - print("Testing Softmax activation") - time.sleep(1) - test_softmax_activation(N) - test_softmax_grad(N) - - print("Testing Tanh activation") - time.sleep(1) - test_tanh_grad(N) - - print("Testing ReLU activation") - time.sleep(1) - test_relu_activation(N) - test_relu_grad(N) - - print("Testing ELU activation") - time.sleep(1) - test_elu_activation(N) - test_elu_grad(N) - - print("Testing SoftPlus activation") - time.sleep(1) - test_softplus_activation(N) - test_softplus_grad(N) - - -def test_layers(N=50): - print("Testing FullyConnected layer") - time.sleep(1) - test_FullyConnected(N) - - print("Testing Conv1D layer") - time.sleep(1) - test_Conv1D(N) - - print("Testing Conv2D layer") - time.sleep(1) - test_Conv2D(N) - - print("Testing Pool2D layer") - time.sleep(1) - test_Pool2D(N) - - print("Testing BatchNorm1D layer") - time.sleep(1) - test_BatchNorm1D(N) - - print("Testing BatchNorm2D layer") - time.sleep(1) - test_BatchNorm2D(N) - - print("Testing LayerNorm1D layer") - time.sleep(1) - test_LayerNorm1D(N) - - print("Testing LayerNorm2D layer") - time.sleep(1) - test_LayerNorm2D(N) - - print("Testing Deconv2D layer") - time.sleep(1) - test_Deconv2D(N) - - print("Testing Add layer") - time.sleep(1) - test_AddLayer(N) - - print("Testing Multiply layer") - time.sleep(1) - test_MultiplyLayer(N) - - print("Testing LSTMCell layer") - time.sleep(1) - test_LSTMCell(N) - - print("Testing RNNCell layer") - time.sleep(1) - test_RNNCell(N) - - print("Testing DotProductAttention layer") - time.sleep(1) - test_DPAttention(N) - - -def test_utils(N=50): - print("Testing pad1D util") - time.sleep(1) - test_pad1D(N) - - print("Testing conv2D util") - time.sleep(1) - test_conv(N) - - -def test_modules(N=50): - print("Testing MultiHeadedAttentionModule") - time.sleep(1) - test_MultiHeadedAttentionModule(N) - - print("Testing BidirectionalLSTM module") - time.sleep(1) - test_BidirectionalLSTM(N) - - print("Testing WaveNet module") - time.sleep(1) - test_WaveNetModule(N) - - print("Testing SkipConnectionIdentity module") - time.sleep(1) - test_SkipConnectionIdentityModule(N) - - print("Testing SkipConnectionConv module") - time.sleep(1) - test_SkipConnectionConvModule(N) - - ####################################################################### # Loss Functions # ####################################################################### -def test_squared_error(N=None): - from ..losses import SquaredError +def test_squared_error(N=15): + from numpy_ml.neural_nets.losses import SquaredError + + np.random.seed(12345) N = np.inf if N is None else N @@ -264,8 +116,10 @@ def test_squared_error(N=None): i += 1 -def test_cross_entropy(N=None): - from ..losses import CrossEntropy +def test_cross_entropy(N=15): + from numpy_ml.neural_nets.losses import CrossEntropy + + np.random.seed(12345) N = np.inf if N is None else N @@ -292,17 +146,20 @@ def test_cross_entropy(N=None): i += 1 -def test_VAE_loss(N=None): - from ..losses import VAELoss +def test_VAE_loss(N=15): + from numpy_ml.neural_nets.losses import VAELoss + + np.random.seed(12345) N = np.inf if N is None else N + eps = np.finfo(float).eps i = 1 while i < N: n_ex = np.random.randint(1, 10) t_dim = np.random.randint(2, 10) t_mean = random_tensor([n_ex, t_dim], standardize=True) - t_log_var = np.log(np.abs(random_tensor([n_ex, t_dim], standardize=True))) + t_log_var = np.log(np.abs(random_tensor([n_ex, t_dim], standardize=True) + eps)) im_cols, im_rows = np.random.randint(2, 40), np.random.randint(2, 40) X = np.random.rand(n_ex, im_rows * im_cols) X_recon = np.random.rand(n_ex, im_rows * im_cols) @@ -331,8 +188,10 @@ def test_VAE_loss(N=None): i += 1 -def test_WGAN_GP_loss(N=None): - from ..losses import WGAN_GPLoss +def test_WGAN_GP_loss(N=5): + from numpy_ml.neural_nets.losses import WGAN_GPLoss + + np.random.seed(12345) N = np.inf if N is None else N @@ -381,8 +240,8 @@ def test_WGAN_GP_loss(N=None): i += 1 -def test_NCELoss(N=None): - from ..losses import NCELoss +def test_NCELoss(N=1): + from numpy_ml.neural_nets.losses import NCELoss from numpy_ml.utils.data_structures import DiscreteSampler np.random.seed(12345) @@ -479,9 +338,11 @@ def test_NCELoss(N=None): ####################################################################### -def test_squared_error_grad(N=None): - from ..losses import SquaredError - from ..activations import Tanh +def test_squared_error_grad(N=15): + from numpy_ml.neural_nets.losses import SquaredError + from numpy_ml.neural_nets.activations import Tanh + + np.random.seed(12345) N = np.inf if N is None else N @@ -500,15 +361,17 @@ def test_squared_error_grad(N=None): y_pred = act.fn(z) assert_almost_equal( - mine.grad(y, y_pred, z, act), 0.5 * gold(y, z, F.tanh), decimal=4 + mine.grad(y, y_pred, z, act), 0.5 * gold(y, z, torch.tanh), decimal=4 ) print("PASSED") i += 1 -def test_cross_entropy_grad(N=None): - from ..losses import CrossEntropy - from ..layers import Softmax +def test_cross_entropy_grad(N=15): + from numpy_ml.neural_nets.losses import CrossEntropy + from numpy_ml.neural_nets.layers import Softmax + + np.random.seed(12345) N = np.inf if N is None else N @@ -537,8 +400,10 @@ def test_cross_entropy_grad(N=None): ####################################################################### -def test_sigmoid_activation(N=None): - from ..activations import Sigmoid +def test_sigmoid_activation(N=15): + from numpy_ml.neural_nets.activations import Sigmoid + + np.random.seed(12345) N = np.inf if N is None else N @@ -554,8 +419,10 @@ def test_sigmoid_activation(N=None): i += 1 -def test_elu_activation(N=None): - from ..activations import ELU +def test_elu_activation(N=15): + from numpy_ml.neural_nets.activations import ELU + + np.random.seed(12345) N = np.inf if N is None else N @@ -574,8 +441,10 @@ def test_elu_activation(N=None): i += 1 -def test_softmax_activation(N=None): - from ..layers import Softmax +def test_softmax_activation(N=15): + from numpy_ml.neural_nets.layers import Softmax + + np.random.seed(12345) N = np.inf if N is None else N @@ -591,8 +460,10 @@ def test_softmax_activation(N=None): i += 1 -def test_relu_activation(N=None): - from ..activations import ReLU +def test_relu_activation(N=15): + from numpy_ml.neural_nets.activations import ReLU + + np.random.seed(12345) N = np.inf if N is None else N @@ -608,8 +479,10 @@ def test_relu_activation(N=None): i += 1 -def test_softplus_activation(N=None): - from ..activations import SoftPlus +def test_softplus_activation(N=15): + from numpy_ml.neural_nets.activations import SoftPlus + + np.random.seed(12345) N = np.inf if N is None else N @@ -630,13 +503,15 @@ def test_softplus_activation(N=None): ####################################################################### -def test_sigmoid_grad(N=None): - from ..activations import Sigmoid +def test_sigmoid_grad(N=15): + from numpy_ml.neural_nets.activations import Sigmoid + + np.random.seed(12345) N = np.inf if N is None else N mine = Sigmoid() - gold = torch_gradient_generator(F.sigmoid) + gold = torch_gradient_generator(torch.sigmoid) i = 0 while i < N: @@ -648,8 +523,10 @@ def test_sigmoid_grad(N=None): i += 1 -def test_elu_grad(N=None): - from ..activations import ELU +def test_elu_grad(N=15): + from numpy_ml.neural_nets.activations import ELU + + np.random.seed(12345) N = np.inf if N is None else N @@ -662,18 +539,20 @@ def test_elu_grad(N=None): mine = ELU(alpha) gold = torch_gradient_generator(F.elu, alpha=alpha) - assert_almost_equal(mine.grad(z), gold(z)) + assert_almost_equal(mine.grad(z), gold(z), decimal=5) print("PASSED") i += 1 -def test_tanh_grad(N=None): - from ..activations import Tanh +def test_tanh_grad(N=15): + from numpy_ml.neural_nets.activations import Tanh + + np.random.seed(12345) N = np.inf if N is None else N mine = Tanh() - gold = torch_gradient_generator(F.tanh) + gold = torch_gradient_generator(torch.tanh) i = 0 while i < N: @@ -685,8 +564,10 @@ def test_tanh_grad(N=None): i += 1 -def test_relu_grad(N=None): - from ..activations import ReLU +def test_relu_grad(N=15): + from numpy_ml.neural_nets.activations import ReLU + + np.random.seed(12345) N = np.inf if N is None else N @@ -703,8 +584,8 @@ def test_relu_grad(N=None): i += 1 -def test_softmax_grad(N=None): - from ..layers import Softmax +def test_softmax_grad(N=15): + from numpy_ml.neural_nets.layers import Softmax from functools import partial np.random.seed(12345) @@ -733,8 +614,10 @@ def test_softmax_grad(N=None): i += 1 -def test_softplus_grad(N=None): - from ..activations import SoftPlus +def test_softplus_grad(N=15): + from numpy_ml.neural_nets.activations import SoftPlus + + np.random.seed(12345) N = np.inf if N is None else N @@ -756,9 +639,11 @@ def test_softplus_grad(N=None): ####################################################################### -def test_FullyConnected(N=None): - from ..layers import FullyConnected - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_FullyConnected(N=15): + from numpy_ml.neural_nets.layers import FullyConnected + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine + + np.random.seed(12345) N = np.inf if N is None else N @@ -813,8 +698,10 @@ def test_FullyConnected(N=None): i += 1 -def test_Embedding(N=None): - from ..layers import Embedding +def test_Embedding(N=15): + from numpy_ml.neural_nets.layers import Embedding + + np.random.seed(12345) N = np.inf if N is None else N @@ -860,8 +747,10 @@ def test_Embedding(N=None): i += 1 -def test_BatchNorm1D(N=None): - from ..layers import BatchNorm1D +def test_BatchNorm1D(N=15): + from numpy_ml.neural_nets.layers import BatchNorm1D + + np.random.seed(12345) N = np.inf if N is None else N @@ -910,8 +799,8 @@ def test_BatchNorm1D(N=None): i += 1 -def test_LayerNorm1D(N=None): - from ..layers import LayerNorm1D +def test_LayerNorm1D(N=15): + from numpy_ml.neural_nets.layers import LayerNorm1D N = np.inf if N is None else N @@ -956,8 +845,8 @@ def test_LayerNorm1D(N=None): i += 1 -def test_LayerNorm2D(N=None): - from ..layers import LayerNorm2D +def test_LayerNorm2D(N=15): + from numpy_ml.neural_nets.layers import LayerNorm2D N = np.inf if N is None else N @@ -1009,9 +898,9 @@ def test_LayerNorm2D(N=None): i += 1 -def test_MultiplyLayer(N=None): - from ..layers import Multiply - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_MultiplyLayer(N=15): + from numpy_ml.neural_nets.layers import Multiply + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -1065,9 +954,9 @@ def test_MultiplyLayer(N=None): i += 1 -def test_AddLayer(N=None): - from ..layers import Add - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_AddLayer(N=15): + from numpy_ml.neural_nets.layers import Add + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -1121,8 +1010,8 @@ def test_AddLayer(N=None): i += 1 -def test_BatchNorm2D(N=None): - from ..layers import BatchNorm2D +def test_BatchNorm2D(N=15): + from numpy_ml.neural_nets.layers import BatchNorm2D N = np.inf if N is None else N @@ -1177,8 +1066,8 @@ def test_BatchNorm2D(N=None): i += 1 -def test_RNNCell(N=None): - from ..layers import RNNCell +def test_RNNCell(N=15): + from numpy_ml.neural_nets.layers import RNNCell N = np.inf if N is None else N @@ -1240,9 +1129,9 @@ def test_RNNCell(N=None): i += 1 -def test_Conv2D(N=None): - from ..layers import Conv2D - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_Conv2D(N=15): + from numpy_ml.neural_nets.layers import Conv2D + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -1326,8 +1215,8 @@ def test_Conv2D(N=None): i += 1 -def test_DPAttention(N=None): - from ..layers import DotProductAttention +def test_DPAttention(N=15): + from numpy_ml.neural_nets.layers import DotProductAttention N = np.inf if N is None else N @@ -1377,9 +1266,9 @@ def test_DPAttention(N=None): i += 1 -def test_Conv1D(N=None): - from ..layers import Conv1D - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_Conv1D(N=15): + from numpy_ml.neural_nets.layers import Conv1D + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -1458,9 +1347,9 @@ def test_Conv1D(N=None): i += 1 -def test_Deconv2D(N=None): - from ..layers import Deconv2D - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_Deconv2D(N=15): + from numpy_ml.neural_nets.layers import Deconv2D + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -1540,8 +1429,8 @@ def test_Deconv2D(N=None): i += 1 -def test_Pool2D(N=None): - from ..layers import Pool2D +def test_Pool2D(N=15): + from numpy_ml.neural_nets.layers import Pool2D N = np.inf if N is None else N @@ -1592,8 +1481,8 @@ def test_Pool2D(N=None): i += 1 -def test_LSTMCell(N=None): - from ..layers import LSTMCell +def test_LSTMCell(N=15): + from numpy_ml.neural_nets.layers import LSTMCell N = np.inf if N is None else N @@ -1716,8 +1605,8 @@ def grad_check_RNN(model, loss_func, param_name, n_t, X, epsilon=1e-7): ####################################################################### -def test_MultiHeadedAttentionModule(N=None): - from ..modules import MultiHeadedAttentionModule +def test_MultiHeadedAttentionModule(N=15): + from numpy_ml.neural_nets.modules import MultiHeadedAttentionModule N = np.inf if N is None else N np.random.seed(12345) @@ -1802,9 +1691,9 @@ def test_MultiHeadedAttentionModule(N=None): i += 1 -def test_SkipConnectionIdentityModule(N=None): - from ..modules import SkipConnectionIdentityModule - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_SkipConnectionIdentityModule(N=15): + from numpy_ml.neural_nets.modules import SkipConnectionIdentityModule + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -1932,9 +1821,9 @@ def test_SkipConnectionIdentityModule(N=None): i += 1 -def test_SkipConnectionConvModule(N=None): - from ..modules import SkipConnectionConvModule - from ..activations import Tanh, ReLU, Sigmoid, Affine +def test_SkipConnectionConvModule(N=15): + from numpy_ml.neural_nets.modules import SkipConnectionConvModule + from numpy_ml.neural_nets.activations import Tanh, ReLU, Sigmoid, Affine N = np.inf if N is None else N @@ -2100,8 +1989,8 @@ def test_SkipConnectionConvModule(N=None): i += 1 -def test_BidirectionalLSTM(N=None): - from ..modules import BidirectionalLSTM +def test_BidirectionalLSTM(N=15): + from numpy_ml.neural_nets.modules import BidirectionalLSTM N = np.inf if N is None else N @@ -2182,8 +2071,8 @@ def test_BidirectionalLSTM(N=None): i += 1 -def test_WaveNetModule(N=None): - from ..modules import WavenetResidualModule +def test_WaveNetModule(N=10): + from numpy_ml.neural_nets.modules import WavenetResidualModule N = np.inf if N is None else N @@ -2284,9 +2173,11 @@ def test_WaveNetModule(N=None): ####################################################################### -def test_pad1D(N=None): - from ..layers import Conv1D - from .torch_models import TorchCausalConv1d, torchify +def test_pad1D(N=15): + from numpy_ml.neural_nets.layers import Conv1D + from .nn_torch_models import TorchCausalConv1d, torchify + + np.random.seed(12345) N = np.inf if N is None else N @@ -2381,7 +2272,8 @@ def test_pad1D(N=None): i += 1 -def test_conv(N=None): +def test_conv(N=15): + np.random.seed(12345) N = np.inf if N is None else N i = 0 while i < N: @@ -2414,10 +2306,11 @@ def test_conv(N=None): ####################################################################### -def test_VAE(): +def fit_VAE(): # for testing - from keras.datasets import mnist - from ..models.vae import BernoulliVAE + from numpy_ml.neural_nets.models.vae import BernoulliVAE + + np.random.seed(12345) (X_train, y_train), (X_test, y_test) = mnist.load_data() @@ -2425,14 +2318,16 @@ def test_VAE(): X_train = np.expand_dims(X_train.astype("float32") / 255.0, 3) X_test = np.expand_dims(X_test.astype("float32") / 255.0, 3) - X_train = X_train[: 128 * 10] + X_train = X_train[: 128 * 1] # 1 batch BV = BernoulliVAE() - BV.fit(X_train, verbose=True) + BV.fit(X_train, n_epochs=1, verbose=False) def test_WGAN_GP(N=1): - from ..models.wgan_gp import WGAN_GP + from numpy_ml.neural_nets.models.wgan_gp import WGAN_GP + + np.random.seed(12345) ss = np.random.randint(0, 1000) np.random.seed(ss) diff --git a/numpy_ml/neural_nets/activations/tests.py b/numpy_ml/tests/test_nn_activations.py similarity index 74% rename from numpy_ml/neural_nets/activations/tests.py rename to numpy_ml/tests/test_nn_activations.py index 18d799d..99bb294 100644 --- a/numpy_ml/neural_nets/activations/tests.py +++ b/numpy_ml/tests/test_nn_activations.py @@ -1,4 +1,4 @@ -import sys +# flake8: noqa import time import numpy as np @@ -8,8 +8,7 @@ import torch import torch.nn.functional as F -sys.path.append("../..") -from utils.testing import random_stochastic_matrix, random_tensor +from numpy_ml.utils.testing import random_stochastic_matrix, random_tensor def torch_gradient_generator(fn, **kwargs): @@ -46,56 +45,56 @@ def err_fmt(params, golds, ix, warn_str=""): ####################################################################### # Test Suite # ####################################################################### - - -def test_activations(N=50): - print("Testing Sigmoid activation") - time.sleep(1) - test_sigmoid_activation(N) - test_sigmoid_grad(N) - - # print("Testing Softmax activation") - # time.sleep(1) - # test_softmax_activation(N) - # test_softmax_grad(N) - - print("Testing Tanh activation") - time.sleep(1) - test_tanh_grad(N) - - print("Testing ReLU activation") - time.sleep(1) - test_relu_activation(N) - test_relu_grad(N) - - print("Testing ELU activation") - time.sleep(1) - test_elu_activation(N) - test_elu_grad(N) - - print("Testing SELU activation") - time.sleep(1) - test_selu_activation(N) - test_selu_grad(N) - - print("Testing LeakyRelu activation") - time.sleep(1) - test_leakyrelu_activation(N) - test_leakyrelu_grad(N) - - print("Testing SoftPlus activation") - time.sleep(1) - test_softplus_activation(N) - test_softplus_grad(N) - +# +# +# def test_activations(N=50): +# print("Testing Sigmoid activation") +# time.sleep(1) +# test_sigmoid_activation(N) +# test_sigmoid_grad(N) +# +# # print("Testing Softmax activation") +# # time.sleep(1) +# # test_softmax_activation(N) +# # test_softmax_grad(N) +# +# print("Testing Tanh activation") +# time.sleep(1) +# test_tanh_grad(N) +# +# print("Testing ReLU activation") +# time.sleep(1) +# test_relu_activation(N) +# test_relu_grad(N) +# +# print("Testing ELU activation") +# time.sleep(1) +# test_elu_activation(N) +# test_elu_grad(N) +# +# print("Testing SELU activation") +# time.sleep(1) +# test_selu_activation(N) +# test_selu_grad(N) +# +# print("Testing LeakyRelu activation") +# time.sleep(1) +# test_leakyrelu_activation(N) +# test_leakyrelu_grad(N) +# +# print("Testing SoftPlus activation") +# time.sleep(1) +# test_softplus_activation(N) +# test_softplus_grad(N) +# ####################################################################### # Activations # ####################################################################### -def test_sigmoid_activation(N=None): - from activations import Sigmoid +def test_sigmoid_activation(N=50): + from numpy_ml.neural_nets.activations import Sigmoid N = np.inf if N is None else N @@ -111,8 +110,8 @@ def test_sigmoid_activation(N=None): i += 1 -def test_softplus_activation(N=None): - from activations import SoftPlus +def test_softplus_activation(N=50): + from numpy_ml.neural_nets.activations import SoftPlus N = np.inf if N is None else N @@ -128,8 +127,8 @@ def test_softplus_activation(N=None): i += 1 -def test_elu_activation(N=None): - from activations import ELU +def test_elu_activation(N=50): + from numpy_ml.neural_nets.activations import ELU N = np.inf if N is None else N @@ -148,8 +147,8 @@ def test_elu_activation(N=None): i += 1 -def test_relu_activation(N=None): - from activations import ReLU +def test_relu_activation(N=50): + from numpy_ml.neural_nets.activations import ReLU N = np.inf if N is None else N @@ -165,8 +164,8 @@ def test_relu_activation(N=None): i += 1 -def test_selu_activation(N=None): - from activations import SELU +def test_selu_activation(N=50): + from numpy_ml.neural_nets.activations import SELU N = np.inf if N is None else N @@ -182,8 +181,8 @@ def test_selu_activation(N=None): i += 1 -def test_leakyrelu_activation(N=None): - from activations import LeakyReLU +def test_leakyrelu_activation(N=50): + from numpy_ml.neural_nets.activations import LeakyReLU N = np.inf if N is None else N @@ -206,8 +205,8 @@ def test_leakyrelu_activation(N=None): ####################################################################### -def test_sigmoid_grad(N=None): - from activations import Sigmoid +def test_sigmoid_grad(N=50): + from numpy_ml.neural_nets.activations import Sigmoid N = np.inf if N is None else N @@ -224,8 +223,8 @@ def test_sigmoid_grad(N=None): i += 1 -def test_elu_grad(N=None): - from activations import ELU +def test_elu_grad(N=50): + from numpy_ml.neural_nets.activations import ELU N = np.inf if N is None else N @@ -243,8 +242,8 @@ def test_elu_grad(N=None): i += 1 -def test_tanh_grad(N=None): - from activations import Tanh +def test_tanh_grad(N=50): + from numpy_ml.neural_nets.activations import Tanh N = np.inf if N is None else N @@ -261,8 +260,8 @@ def test_tanh_grad(N=None): i += 1 -def test_relu_grad(N=None): - from activations import ReLU +def test_relu_grad(N=50): + from numpy_ml.neural_nets.activations import ReLU N = np.inf if N is None else N @@ -279,8 +278,8 @@ def test_relu_grad(N=None): i += 1 -def test_selu_grad(N=None): - from activations import SELU +def test_selu_grad(N=50): + from numpy_ml.neural_nets.activations import SELU N = np.inf if N is None else N @@ -297,8 +296,8 @@ def test_selu_grad(N=None): i += 1 -def test_leakyrelu_grad(N=None): - from activations import LeakyReLU +def test_leakyrelu_grad(N=50): + from numpy_ml.neural_nets.activations import LeakyReLU N = np.inf if N is None else N @@ -316,8 +315,8 @@ def test_leakyrelu_grad(N=None): i += 1 -def test_softplus_grad(N=None): - from activations import SoftPlus +def test_softplus_grad(N=50): + from numpy_ml.neural_nets.activations import SoftPlus N = np.inf if N is None else N diff --git a/numpy_ml/nonparametric/tests.py b/numpy_ml/tests/test_nonparametric.py similarity index 84% rename from numpy_ml/nonparametric/tests.py rename to numpy_ml/tests/test_nonparametric.py index 9c8d443..9e2ec7e 100644 --- a/numpy_ml/nonparametric/tests.py +++ b/numpy_ml/tests/test_nonparametric.py @@ -1,15 +1,19 @@ +# flake8: noqa import numpy as np from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier from sklearn.gaussian_process import GaussianProcessRegressor -from .knn import KNN -from .gp import GPRegression -from ..utils.distance_metrics import euclidean +from numpy_ml.nonparametric.knn import KNN +from numpy_ml.nonparametric.gp import GPRegression +from numpy_ml.utils.distance_metrics import euclidean -def test_knn_regression(): - while True: +def test_knn_regression(N=15): + np.random.seed(12345) + + i = 0 + while i < N: N = np.random.randint(2, 100) M = np.random.randint(2, 100) k = np.random.randint(1, N) @@ -40,14 +44,18 @@ def test_knn_regression(): for mine, theirs in zip(preds, gold_preds): np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 + +def test_knn_clf(N=15): + np.random.seed(12345) -def test_knn_clf(): - while True: + i = 0 + while i < N: N = np.random.randint(2, 100) M = np.random.randint(2, 100) k = np.random.randint(1, N) - n_classes = np.random.randint(10) + n_classes = np.random.randint(2, 10) ls = np.min([np.random.randint(1, 10), N - 1]) weights = "uniform" @@ -61,10 +69,10 @@ def test_knn_clf(): gold = KNeighborsClassifier( p=2, + metric="minkowski", leaf_size=ls, n_neighbors=k, weights=weights, - metric="minkowski", algorithm="ball_tree", ) gold.fit(X, y) @@ -73,10 +81,14 @@ def test_knn_clf(): for mine, theirs in zip(preds, gold_preds): np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 + +def test_gp_regression(N=15): + np.random.seed(12345) -def test_gp_regression(): - while True: + i = 0 + while i < N: alpha = np.random.rand() N = np.random.randint(2, 100) M = np.random.randint(2, 100) @@ -104,3 +116,4 @@ def test_gp_regression(): np.testing.assert_almost_equal(mll, gold_mll) print("PASSED") + i += 1 diff --git a/numpy_ml/preprocessing/tests.py b/numpy_ml/tests/test_preprocessing.py similarity index 79% rename from numpy_ml/preprocessing/tests.py rename to numpy_ml/tests/test_preprocessing.py index 333cc9a..793e31c 100644 --- a/numpy_ml/preprocessing/tests.py +++ b/numpy_ml/tests/test_preprocessing.py @@ -1,3 +1,4 @@ +# flake8: noqa from collections import Counter # gold-standard imports @@ -15,14 +16,24 @@ from librosa.filters import mel # numpy-ml implementations -from .general import Standardizer -from .nlp import HuffmanEncoder, TFIDFEncoder -from .dsp import DCT, DFT, mfcc, to_frames, mel_filterbank, dft_bins -from ..utils.testing import random_paragraph - - -def test_huffman(): - while True: +from numpy_ml.preprocessing.general import Standardizer +from numpy_ml.preprocessing.nlp import HuffmanEncoder, TFIDFEncoder +from numpy_ml.preprocessing.dsp import ( + DCT, + DFT, + mfcc, + to_frames, + mel_filterbank, + dft_bins, +) +from numpy_ml.utils.testing import random_paragraph + + +def test_huffman(N=15): + np.random.seed(12345) + + i = 0 + while i < N: n_words = np.random.randint(1, 100) para = random_paragraph(n_words) HT = HuffmanEncoder() @@ -35,10 +46,14 @@ def test_huffman(): assert k in my_dict, "key `{}` not in my_dict".format(k) assert my_dict[k] == v, fstr.format(k, v, k, my_dict[k]) print("PASSED") + i += 1 + +def test_standardizer(N=15): + np.random.seed(12345) -def test_standardizer(): - while True: + i = 0 + while i < N: mean = bool(np.random.randint(2)) std = bool(np.random.randint(2)) N = np.random.randint(2, 100) @@ -54,10 +69,14 @@ def test_standardizer(): np.testing.assert_almost_equal(mine, gold) print("PASSED") + i += 1 -def test_tfidf(): - while True: +def test_tfidf(N=15): + np.random.seed(12345) + + i = 0 + while i < N: docs = [] n_docs = np.random.randint(1, 10) for d in range(n_docs): @@ -90,10 +109,14 @@ def test_tfidf(): np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 + +def test_dct(N=15): + np.random.seed(12345) -def test_dct(): - while True: + i = 0 + while i < N: N = np.random.randint(2, 100) signal = np.random.rand(N) ortho = bool(np.random.randint(2)) @@ -102,10 +125,14 @@ def test_dct(): np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 + +def test_dft(N=15): + np.random.seed(12345) -def test_dft(): - while True: + i = 0 + while i < N: N = np.random.randint(2, 100) signal = np.random.rand(N) mine = DFT(signal) @@ -113,13 +140,17 @@ def test_dft(): np.testing.assert_almost_equal(mine.real, theirs.real) print("PASSED") + i += 1 -def test_mfcc(): +def test_mfcc(N=1): """Broken""" - while True: - N = np.random.randint(500, 100000) - fs = np.random.randint(50, 10000) + np.random.seed(12345) + + i = 0 + while i < N: + N = np.random.randint(500, 1000) + fs = np.random.randint(50, 100) n_mfcc = 12 window_len = 100 stride_len = 50 @@ -127,8 +158,6 @@ def test_mfcc(): window_dur = window_len / fs stride_dur = stride_len / fs signal = np.random.rand(N) - # ff = frame(signal, frame_length=window_len, hop_length=stride_len).T - # print(len(ff)) mine = mfcc( signal, @@ -155,12 +184,16 @@ def test_mfcc(): htk=True, ).T - np.testing.assert_almost_equal(mine, theirs, decimal=5) + np.testing.assert_almost_equal(mine, theirs, decimal=4) print("PASSED") + i += 1 -def test_framing(): - while True: +def test_framing(N=15): + np.random.seed(12345) + + i = 0 + while i < N: N = np.random.randint(500, 100000) window_len = np.random.randint(10, 100) stride_len = np.random.randint(1, 50) @@ -174,10 +207,14 @@ def test_framing(): ) np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 + +def test_dft_bins(N=15): + np.random.seed(12345) -def test_dft_bins(): - while True: + i = 0 + while i < N: N = np.random.randint(500, 100000) fs = np.random.randint(50, 1000) @@ -185,10 +222,14 @@ def test_dft_bins(): theirs = fft_frequencies(fs, N) np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 + +def test_mel_filterbank(N=15): + np.random.seed(12345) -def test_mel_filterbank(): - while True: + i = 0 + while i < N: fs = np.random.randint(50, 10000) n_filters = np.random.randint(2, 20) window_len = np.random.randint(10, 100) @@ -208,3 +249,4 @@ def test_mel_filterbank(): np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 diff --git a/numpy_ml/trees/tests.py b/numpy_ml/tests/test_trees.py similarity index 65% rename from numpy_ml/trees/tests.py rename to numpy_ml/tests/test_trees.py index 87130df..4a90fb5 100644 --- a/numpy_ml/trees/tests.py +++ b/numpy_ml/tests/test_trees.py @@ -1,37 +1,16 @@ +# flake8: noqa import numpy as np from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.metrics import accuracy_score, mean_squared_error -from sklearn.datasets import make_regression -from sklearn.datasets.samples_generator import make_blobs +from sklearn.datasets import make_regression, make_blobs from sklearn.model_selection import train_test_split -import matplotlib - -matplotlib.use("TkAgg") -import matplotlib.pyplot as plt - -# https://seaborn.pydata.org/generated/seaborn.set_context.html -# https://seaborn.pydata.org/generated/seaborn.set_style.html -import seaborn as sns - -sns.set_style("white") -sns.set_context("paper", font_scale=0.9) - -from .gbdt import GradientBoostedDecisionTree -from .dt import DecisionTree, Node, Leaf -from .rf import RandomForest - - -def random_tensor(shape, standardize=False): - eps = np.finfo(float).eps - offset = np.random.randint(-300, 300, shape) - X = np.random.rand(*shape) + offset - - if standardize: - X = (X - X.mean(axis=0)) / (X.std(axis=0) + eps) - return X +from numpy_ml.trees.gbdt import GradientBoostedDecisionTree +from numpy_ml.trees.dt import DecisionTree, Node, Leaf +from numpy_ml.trees.rf import RandomForest +from numpy_ml.utils.testing import random_tensor def clone_tree(dtree): @@ -88,10 +67,10 @@ def test(mine, clone): raise ValueError("Nodes at depth {} are not equal".format(depth)) -def test_DecisionTree(): +def test_DecisionTree(N=1): i = 1 np.random.seed(12345) - while True: + while i <= N: n_ex = np.random.randint(2, 100) n_feats = np.random.randint(2, 100) max_depth = np.random.randint(1, 5) @@ -172,10 +151,10 @@ def loss(yp, y): i += 1 -def test_RandomForest(): +def test_RandomForest(N=1): np.random.seed(12345) i = 1 - while True: + while i <= N: n_ex = np.random.randint(2, 100) n_feats = np.random.randint(2, 100) n_trees = np.random.randint(2, 100) @@ -273,10 +252,10 @@ def loss(yp, y): i += 1 -def test_gbdt(): +def test_gbdt(N=1): np.random.seed(12345) i = 1 - while True: + while i <= N: n_ex = np.random.randint(2, 100) n_feats = np.random.randint(2, 100) n_trees = np.random.randint(2, 100) @@ -298,13 +277,12 @@ def loss(yp, y): # initialize model criterion = np.random.choice(["entropy", "gini"]) mine = GradientBoostedDecisionTree( + n_iter=n_trees, classifier=classifier, - n_trees=n_trees, max_depth=max_depth, learning_rate=0.1, loss="crossentropy", step_size="constant", - split_criterion=criterion, ) gold = RandomForestClassifier( n_estimators=n_trees, @@ -322,13 +300,12 @@ def loss(yp, y): criterion = "mse" loss = mean_squared_error mine = GradientBoostedDecisionTree( - n_trees=n_trees, + n_iter=n_trees, max_depth=max_depth, classifier=classifier, learning_rate=0.1, loss="mse", step_size="constant", - split_criterion=criterion, ) gold = RandomForestRegressor( n_estimators=n_trees, @@ -376,147 +353,3 @@ def loss(yp, y): print("PASSED") i += 1 - - -def plot(): - fig, axes = plt.subplots(4, 4) - fig.set_size_inches(10, 10) - for ax in axes.flatten(): - n_ex = 100 - n_trees = 50 - n_feats = np.random.randint(2, 100) - max_depth_d = np.random.randint(1, 100) - max_depth_r = np.random.randint(1, 10) - - classifier = np.random.choice([True, False]) - if classifier: - # create classification problem - n_classes = np.random.randint(2, 10) - X, Y = make_blobs(n_samples=n_ex, centers=n_classes, n_features=2) - X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3) - n_feats = min(n_feats, X.shape[1]) - - # initialize model - def loss(yp, y): - return accuracy_score(yp, y) - - # initialize model - criterion = np.random.choice(["entropy", "gini"]) - mine = RandomForest( - classifier=classifier, - n_feats=n_feats, - n_trees=n_trees, - criterion=criterion, - max_depth=max_depth_r, - ) - mine_d = DecisionTree( - criterion=criterion, max_depth=max_depth_d, classifier=classifier - ) - mine_g = GradientBoostedDecisionTree( - n_trees=n_trees, - max_depth=max_depth_d, - classifier=classifier, - learning_rate=1, - loss="crossentropy", - step_size="constant", - split_criterion=criterion, - ) - - else: - # create regeression problem - X, Y = make_regression(n_samples=n_ex, n_features=1) - X, X_test, Y, Y_test = train_test_split(X, Y, test_size=0.3) - n_feats = min(n_feats, X.shape[1]) - - # initialize model - criterion = "mse" - loss = mean_squared_error - mine = RandomForest( - criterion=criterion, - n_feats=n_feats, - n_trees=n_trees, - max_depth=max_depth_r, - classifier=classifier, - ) - mine_d = DecisionTree( - criterion=criterion, max_depth=max_depth_d, classifier=classifier - ) - mine_g = GradientBoostedDecisionTree( - n_trees=n_trees, - max_depth=max_depth_d, - classifier=classifier, - learning_rate=1, - loss="mse", - step_size="adaptive", - split_criterion=criterion, - ) - - # fit 'em - mine.fit(X, Y) - mine_d.fit(X, Y) - mine_g.fit(X, Y) - - # get preds on test set - y_pred_mine_test = mine.predict(X_test) - y_pred_mine_test_d = mine_d.predict(X_test) - y_pred_mine_test_g = mine_g.predict(X_test) - - loss_mine_test = loss(y_pred_mine_test, Y_test) - loss_mine_test_d = loss(y_pred_mine_test_d, Y_test) - loss_mine_test_g = loss(y_pred_mine_test_g, Y_test) - - if classifier: - entries = [ - ("RF", loss_mine_test, y_pred_mine_test), - ("DT", loss_mine_test_d, y_pred_mine_test_d), - ("GB", loss_mine_test_g, y_pred_mine_test_g), - ] - (lbl, test_loss, preds) = entries[np.random.randint(3)] - ax.set_title("{} Accuracy: {:.2f}%".format(lbl, test_loss * 100)) - for i in np.unique(Y_test): - ax.scatter( - X_test[preds == i, 0].flatten(), - X_test[preds == i, 1].flatten(), - # s=0.5, - ) - else: - X_ax = np.linspace( - np.min(X_test.flatten()) - 1, np.max(X_test.flatten()) + 1, 100 - ).reshape(-1, 1) - y_pred_mine_test = mine.predict(X_ax) - y_pred_mine_test_d = mine_d.predict(X_ax) - y_pred_mine_test_g = mine_g.predict(X_ax) - - ax.scatter(X_test.flatten(), Y_test.flatten(), c="b", alpha=0.5) - # s=0.5) - ax.plot( - X_ax.flatten(), - y_pred_mine_test_g.flatten(), - # linewidth=0.5, - label="GB".format(n_trees, n_feats, max_depth_d), - color="red", - ) - ax.plot( - X_ax.flatten(), - y_pred_mine_test.flatten(), - # linewidth=0.5, - label="RF".format(n_trees, n_feats, max_depth_r), - color="cornflowerblue", - ) - ax.plot( - X_ax.flatten(), - y_pred_mine_test_d.flatten(), - # linewidth=0.5, - label="DT".format(max_depth_d), - color="yellowgreen", - ) - ax.set_title( - "GB: {:.1f} / RF: {:.1f} / DT: {:.1f} ".format( - loss_mine_test_g, loss_mine_test, loss_mine_test_d - ) - ) - ax.legend() - ax.xaxis.set_ticklabels([]) - ax.yaxis.set_ticklabels([]) - plt.savefig("plot.png", dpi=300) - plt.close("all") diff --git a/numpy_ml/utils/tests.py b/numpy_ml/tests/test_utils.py similarity index 83% rename from numpy_ml/utils/tests.py rename to numpy_ml/tests/test_utils.py index 1089b0f..7721c99 100644 --- a/numpy_ml/utils/tests.py +++ b/numpy_ml/tests/test_utils.py @@ -1,3 +1,4 @@ +# flake8: noqa import numpy as np import scipy @@ -9,20 +10,26 @@ from sklearn.metrics.pairwise import polynomial_kernel as sk_poly -from .distance_metrics import euclidean -from .kernels import LinearKernel, PolynomialKernel, RBFKernel -from .data_structures import BallTree -from .graphs import DiGraph, UndirectedGraph, Edge, random_unweighted_graph, random_DAG +from numpy_ml.utils.distance_metrics import euclidean +from numpy_ml.utils.kernels import LinearKernel, PolynomialKernel, RBFKernel +from numpy_ml.utils.data_structures import BallTree +from numpy_ml.utils.graphs import ( + DiGraph, + UndirectedGraph, + Edge, + random_unweighted_graph, + random_DAG, +) ####################################################################### # Kernels # ####################################################################### -def test_linear_kernel(): +def test_linear_kernel(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: N = np.random.randint(1, 100) M = np.random.randint(1, 100) C = np.random.randint(1, 1000) @@ -35,12 +42,13 @@ def test_linear_kernel(): np.testing.assert_almost_equal(mine, gold) print("PASSED") + i += 1 -def test_polynomial_kernel(): +def test_polynomial_kernel(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: N = np.random.randint(1, 100) M = np.random.randint(1, 100) C = np.random.randint(1, 1000) @@ -56,12 +64,13 @@ def test_polynomial_kernel(): np.testing.assert_almost_equal(mine, gold) print("PASSED") + i += 1 -def test_radial_basis_kernel(): +def test_radial_basis_kernel(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: N = np.random.randint(1, 100) M = np.random.randint(1, 100) C = np.random.randint(1, 1000) @@ -79,6 +88,7 @@ def test_radial_basis_kernel(): np.testing.assert_almost_equal(mine, gold) print("PASSED") + i += 1 ####################################################################### @@ -86,10 +96,10 @@ def test_radial_basis_kernel(): ####################################################################### -def test_euclidean(): +def test_euclidean(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: N = np.random.randint(1, 100) x = np.random.rand(N) y = np.random.rand(N) @@ -97,6 +107,7 @@ def test_euclidean(): theirs = scipy.spatial.distance.euclidean(x, y) np.testing.assert_almost_equal(mine, theirs) print("PASSED") + i += 1 ####################################################################### @@ -104,10 +115,10 @@ def test_euclidean(): ####################################################################### -def test_ball_tree(): +def test_ball_tree(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: N = np.random.randint(2, 100) M = np.random.randint(2, 100) k = np.random.randint(1, N) @@ -135,11 +146,12 @@ def test_ball_tree(): theirs_dist = theirs_dist.flatten()[sort_ix] theirs_neighb = X[ind.flatten()[sort_ix]] - for i in range(len(theirs_dist)): - np.testing.assert_almost_equal(mine_neighb[i], theirs_neighb[i]) - np.testing.assert_almost_equal(mine_dist[i], theirs_dist[i]) + for j in range(len(theirs_dist)): + np.testing.assert_almost_equal(mine_neighb[j], theirs_neighb[j]) + np.testing.assert_almost_equal(mine_dist[j], theirs_dist[j]) print("PASSED") + i += 1 ####################################################################### @@ -148,7 +160,7 @@ def test_ball_tree(): def from_networkx(G_nx): - """ Convert a networkx graph to my graph representation""" + """Convert a networkx graph to my graph representation""" V = list(G_nx.nodes) edges = list(G_nx.edges) is_weighted = "weight" in G_nx[edges[0][0]][edges[0][1]] @@ -178,13 +190,13 @@ def to_networkx(G): return G_nx -def test_all_paths(): +def test_all_paths(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: p = np.random.rand() directed = np.random.rand() < 0.5 - G = random_unweighted_graph(n_vertices=10, edge_prob=p, directed=directed) + G = random_unweighted_graph(n_vertices=5, edge_prob=p, directed=directed) nodes = G._I2V.keys() G_nx = to_networkx(G) @@ -207,12 +219,13 @@ def test_all_paths(): np.testing.assert_array_equal(p1, p2) print("PASSED") + i += 1 -def test_random_DAG(): +def test_random_DAG(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: p = np.random.uniform(0.25, 1) n_v = np.random.randint(5, 50) @@ -221,12 +234,13 @@ def test_random_DAG(): assert nx.is_directed_acyclic_graph(G_nx) print("PASSED") + i += 1 -def test_topological_ordering(): +def test_topological_ordering(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: p = np.random.uniform(0.25, 1) n_v = np.random.randint(5, 10) @@ -243,12 +257,13 @@ def test_topological_ordering(): assert any([c_i in seen_it for c_i in G.get_neighbors(n_i)]) == False print("PASSED") + i += 1 -def test_is_acyclic(): +def test_is_acyclic(N=1): np.random.seed(12345) - - while True: + i = 0 + while i < N: p = np.random.rand() directed = np.random.rand() < 0.5 G = random_unweighted_graph(n_vertices=10, edge_prob=p, directed=True) @@ -256,3 +271,4 @@ def test_is_acyclic(): assert G.is_acyclic() == nx.is_directed_acyclic_graph(G_nx) print("PASSED") + i += 1 diff --git a/numpy_ml/utils/testing.py b/numpy_ml/utils/testing.py index 0d45dbf..c56d395 100644 --- a/numpy_ml/utils/testing.py +++ b/numpy_ml/utils/testing.py @@ -1,3 +1,4 @@ +"""Utilities for writing unit tests""" import numbers import numpy as np @@ -13,9 +14,7 @@ def is_symmetric(X): def is_symmetric_positive_definite(X): - """ - Check that a matrix `X` is a symmetric and positive-definite. - """ + """Check that a matrix `X` is a symmetric and positive-definite.""" if is_symmetric(X): try: # if matrix is symmetric, check whether the Cholesky decomposition @@ -133,3 +132,12 @@ def random_paragraph(n_words, vocab=None): "gubergren", ] return [np.random.choice(vocab) for _ in range(n_words)] + + +####################################################################### +# Custom Warnings # +####################################################################### + + +class DependencyWarning(RuntimeWarning): + pass diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..6c9f229 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,17 @@ +numpy +scipy +sklearn +torch +networkx +matplotlib +seaborn +tensorflow +gym +keras +huffman +librosa +nltk +hmmlearn +pre-commit +tox +pytest diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..40de83f --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,14 @@ +numpy +scipy +sklearn +torch +networkx +tensorflow +keras +gym +huffman +librosa +nltk +hmmlearn +tox +pytest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6bad103 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +numpy +scipy diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..45ff20b --- /dev/null +++ b/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from codecs import open + +from setuptools import setup, find_packages + +with open("README.md", encoding="utf-8") as f: + LONG_DESCRIPTION = f.read() + +with open("requirements.txt") as requirements: + REQUIREMENTS = [r.strip() for r in requirements if r != "\n"] + +PROJECT_URLS = { + "Bug Tracker": "https://github.com/ddbourgin/numpy-ml/issues", + "Documentation": "https://numpy-ml.readthedocs.io/en/latest/", + "Source": "https://github.com/ddbourgin/numpy-ml", +} + +setup( + name="numpy-ml", + version="0.1.2", + author="David Bourgin", + author_email="ddbourgin@gmail.com", + project_urls=PROJECT_URLS, + url="https://github.com/ddbourgin/numpy-ml", + description="Machine learning in NumPy", + long_description=LONG_DESCRIPTION, + long_description_content_type="text/markdown", + install_requires=REQUIREMENTS, + packages=find_packages(), + license="GPLv3+", + include_package_data=True, + python_requires=">=3.5", + extras_require={"rl": ["gym", "matplotlib"]}, + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Topic :: Scientific/Engineering", + "License :: OSI Approved :: GNU General Public License (GPL)", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + ], +) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..0c65edf --- /dev/null +++ b/tox.ini @@ -0,0 +1,6 @@ +[tox] +envlist = py36,py38 +skip_missing_interpreters=true +[testenv] +deps = -rrequirements-test.txt +commands = pytest