diff --git a/src/gtnet/sequence.py b/src/gtnet/sequence.py
index f3d15f1..5489486 100644
--- a/src/gtnet/sequence.py
+++ b/src/gtnet/sequence.py
@@ -34,6 +34,8 @@ def __init__(self, window, step, vocab=None, padval=None, min_seq_len=100, devic
         self.device = device
 
     def encode(self, seq):
+        if len(seq) < self.min_seq_len:
+            raise ValueError(f"Minimum sequence length is {self.min_seq_len} - got {len(seq)}")
         if seq.dtype == np.dtype('S1'):
             seq = seq.view(np.uint8)
         elif seq.dtype == np.dtype('U1'):
@@ -143,9 +145,15 @@ def readfiles(cls, encoder, fastas):
         for fa in fastas:
             logging.debug(f'loading {fa}')
             for seqid, values in cls.readfile(fa):
-                batches = encoder.encode(values)
-                val = (fa, seqid, len(values), batches)
-                yield val
+                if len(values) < encoder.min_seq_len:
+                    # encode() rejects sequences this short, so warn and yield an empty tensor in place of batches
+                    logging.warning((f"Skipping {seqid} from {fa} - length less than "
+                                     f"minimum sequence length {encoder.min_seq_len}"))
+                    yield (fa, seqid, len(values), torch.zeros((0, 0, 0), dtype=torch.uint8))
+                else:
+                    batches = encoder.encode(values)
+                    val = (fa, seqid, len(values), batches)
+                    yield val
 
 
 class SerialLoader(Loader):