From 07e6c5efff0e6e96394fb7b68f5cad7044574f39 Mon Sep 17 00:00:00 2001 From: dan Date: Thu, 29 Aug 2024 11:19:36 -0700 Subject: [PATCH] add a test for Theta.pop() --- sharktank/tests/types/dataset_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py index 82c9723f0..99d176c5b 100644 --- a/sharktank/tests/types/dataset_test.py +++ b/sharktank/tests/types/dataset_test.py @@ -77,6 +77,22 @@ def testTransform(self): self.assertIsNot(pt1, pt2) torch.testing.assert_close(pt1, pt2) + def testPop(self): + t1 = Theta( + _flat_t_dict( + _t("a.b.c", 1, 2), + _t("a.c.d", 10, 11), + _t("a.b.3", 3, 4), + ) + ) + popped = t1.pop("a.b").flatten() + t1 = t1.flatten() + + self.assertIsNotNone("a.c.d", t1.keys()) + self.assertNotIn("a.b.c", t1.keys()) + self.assertNotIn("a.b.3", t1.keys()) + self.assertIn("a.b.3", popped.keys()) + class DatasetTest(unittest.TestCase): def setUp(self):