From 5d726207f25ce000c4de5ec086a0e18f337dfe36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jan 2024 19:10:38 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lai_tldr/data.py | 11 +++++------ lai_tldr/module.py | 11 +++++------ tests/test_data.py | 1 - 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/lai_tldr/data.py b/lai_tldr/data.py index 81299b3..1f97738 100644 --- a/lai_tldr/data.py +++ b/lai_tldr/data.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - """Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy.""" import os @@ -55,11 +54,11 @@ def __init__( self.target_max_token_len = target_max_token_len def __len__(self): - """returns length of data.""" + """Returns length of data.""" return len(self.data) def __getitem__(self, index: int): - """returns dictionary of input tensors to feed into T5/MT5 model.""" + """Returns dictionary of input tensors to feed into T5/MT5 model.""" data_row = self.data.iloc[index] source_text = data_row["source_text"] @@ -170,7 +169,7 @@ def setup(self, stage=None): ) def train_dataloader(self): - """training dataloader.""" + """Training dataloader.""" return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -179,7 +178,7 @@ def train_dataloader(self): ) def test_dataloader(self): - """test dataloader.""" + """Test dataloader.""" return DataLoader( self.test_dataset, batch_size=self.batch_size, @@ -188,7 +187,7 @@ def test_dataloader(self): ) def val_dataloader(self): - """validation dataloader.""" + """Validation dataloader.""" return DataLoader( self.val_dataset, batch_size=self.batch_size, diff --git a/lai_tldr/module.py b/lai_tldr/module.py index a9b9d80..4e130f9 100644 --- a/lai_tldr/module.py +++ b/lai_tldr/module.py @@ -19,7 +19,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - """Code in this file is based on https://github.com/Shivanandroy/simpleT5 by Shivanand Roy.""" from lightning.pytorch import LightningModule @@ -42,7 +41,7 @@ def __init__( self.save_only_last_epoch = False def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None): - """forward step.""" + """Forward step.""" output = self.model( input_ids, attention_mask=attention_mask, @@ -53,7 +52,7 @@ def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None return output.loss, output.logits def training_step(self, batch, batch_idx): - """training step.""" + """Training step.""" input_ids = batch["source_text_input_ids"] attention_mask = batch["source_text_attention_mask"] labels = batch["labels"] @@ -70,7 +69,7 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): - """validation step.""" + """Validation step.""" input_ids = batch["source_text_input_ids"] attention_mask = batch["source_text_attention_mask"] labels = batch["labels"] @@ -86,7 +85,7 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss, prog_bar=True) def test_step(self, batch, batch_idx): - """test step.""" + """Test step.""" input_ids = batch["source_text_input_ids"] attention_mask = batch["source_text_attention_mask"] labels = batch["labels"] @@ -102,7 +101,7 @@ def test_step(self, batch, batch_idx): self.log("test_loss", loss, prog_bar=True) def configure_optimizers(self): - """configure optimizers.""" + """Configure optimizers.""" return AdamW(self.parameters(), lr=0.0001) diff --git a/tests/test_data.py b/tests/test_data.py index 3bae8ca..eb315db 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -28,7 +28,6 @@ def test_summarization_dataset(): counter = 0 for sample in dset: - assert isinstance(sample, dict) keys = list(sample.keys())