diff --git a/box/box.go b/box/box.go new file mode 100644 index 0000000..71bb116 --- /dev/null +++ b/box/box.go @@ -0,0 +1,236 @@ +package box + +import ( + "sort" + "strings" +) + +// Box is a data structure representing a box in an image, +// with x and y float coordinates, and the text inside the box. +type Box struct { + XLeft float64 + XRight float64 + YBottom float64 + YTop float64 + Content string +} + +// Inside other box o if it is completely inside, +// i.e. all coordinates for the outer box are more extreme or overlap +func (b Box) Inside(o Box) bool { + return o.XLeft <= b.XLeft && o.XRight >= b.XRight && o.YTop <= b.YTop && o.YBottom >= b.YBottom +} + +// Box's x coordinates overlap with the region of left and right +func (b Box) XOverlap(left float64, right float64) bool { + // box is to the left + if b.XRight < left { + return false + } + // box is to the right + if b.XLeft > right { + return false + } + return true +} + +// Box's y coordinates overlap with the region of top and bottom +func (b Box) YOverlap(top float64, bottom float64) bool { + // box is above + if b.YBottom < top { + return false + } + // box is below + if b.YTop > bottom { + return false + } + return true +} + +// Find all non-overlapping regions in x direction of coordinates +// where there is at least one box. +func XRegions(boxes []Box) [][]float64 { + regions := make([][]float64, 0) + for _, b := range boxes { + found := false + for _, region := range regions { + left := region[0] + right := region[1] + if b.XOverlap(left, right) { + region[0] = min(left, b.XLeft) + region[1] = max(right, b.XRight) + found = true + break + } + } + if !found { + regions = append(regions, []float64{b.XLeft, b.XRight}) + } + } + return mergeRegions(regions) +} + +// Find all non-overlapping regions in y direction of coordinates +// where there is at least one box. +func YRegions(boxes []Box) [][]float64 { + regions := make([][]float64, 0) + // iterate over all boxes + for _, b := range boxes { + // overlap = has overlap with existing region + overlap := false + for _, region := range regions { + top := region[0] + bottom := region[1] + // box overlaps with current region; expand the region + if b.YOverlap(top, bottom) { + region[0] = min(top, b.YTop) + region[1] = max(bottom, b.YBottom) + overlap = true + break + } + } + // does not have overlap with existing region + if !overlap { + regions = append(regions, []float64{b.YTop, b.YBottom}) + } + } + // we're likely to have some duplicate regions; merge so we don't have overlap + return mergeRegions(regions) +} + +// Remove duplicates +func mergeRegions(regions [][]float64) [][]float64 { + newRegions := make([][]float64, 0) + for _, r := range regions { + overlap := false + for i, n := range newRegions { + // this region (n) is inside other (r); skip it + if r[0] <= n[0] && n[1] <= r[1] { + overlap = true + break + } + // this region (n) is completely outside other (r) + if n[0] <= r[0] && r[1] <= n[1] { + overlap = true + break + } + // this region is to the left, but not on the right + if n[0] <= r[0] && n[1] <= r[1] && r[0] <= n[1] { + overlap = true + newRegions[i][0] = n[0] + newRegions[i][1] = r[1] + break + } + // this region is to the right, but not to the left + if r[0] <= n[0] && r[1] <= n[1] && n[0] <= r[1] { + overlap = true + newRegions[i][0] = n[0] + newRegions[i][1] = n[1] + break + } + } + if !overlap { + newRegions = append(newRegions, r) + } + } + return newRegions +} + +func min(f1, f2 float64) float64 { + if f1 < f2 { + return f1 + } + return f2 +} + +func max(f1, f2 float64) float64 { + if f1 < f2 { + return f2 + } + return f1 +} + +// Given non-overlapping regions in x direction and y direction, create +func CartesianProduct(xRegions [][]float64, yRegions [][]float64) [][]Box { + rows := make([][]Box, len(yRegions)) + for i, yRegion := range yRegions { + rows[i] = make([]Box, len(xRegions)) + for j, xRegion := range xRegions { + rows[i][j] = Box{ + XLeft: xRegion[0], + XRight: xRegion[1], + YBottom: yRegion[1], + YTop: yRegion[0], + } + } + } + return rows +} + +type Cell []Box + +func (c Cell) Len() int { + return len(c) +} +func (c Cell) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} +func (c Cell) Less(i, j int) bool { + // sort boxes by row, then by x + // find row by checking if the bottom y is above the top y. + // within a row, use xLeft + + // different row + if c[i].YBottom < c[j].YTop { + return true // i should be first + } + if c[i].YTop > c[j].YBottom { + return false // i should be first + } + // same row, so compare x + return c[i].XLeft < c[j].XLeft +} + +func Assign(rows [][]Box, boxes []Box) { + // Fill cell with corresponding boxes + // Find word boxes to put in this cell + // Assign table cell (x, y) to each box: + // find all + for i := range rows { + for j := range rows[i] { + c := Cell(boxes) + sort.Sort(c) + boxes = []Box(c) + for _, b := range boxes { + if b.Inside(rows[i][j]) { + rows[i][j].Content = strings.Trim(rows[i][j].Content+" "+b.Content, " ") + } + } + } + } +} + +func ToTable(boxes []Box) [][]string { + // TODO: Explain this better + // Find all regions in x direction with a box, + // and same in y direction + xRegions := XRegions(boxes) + yRegions := YRegions(boxes) + + // Create all cells by taking the cartesian product + // of x regions and y regions: for each x region, all y regions. + rows := CartesianProduct(xRegions, yRegions) + // Assign table cell (x, y) to each box + // (mutates rows) + Assign(rows, boxes) + + // Create table ([][]string) from [][]box.Box + lines := make([][]string, len(rows)) + for i := range rows { + lines[i] = make([]string, len(rows[i])) + for j := range rows[i] { + lines[i][j] = rows[i][j].Content + } + } + return lines +} diff --git a/cmd/lambda/main.go b/cmd/lambda/main.go index 4a26c25..325392c 100644 --- a/cmd/lambda/main.go +++ b/cmd/lambda/main.go @@ -139,12 +139,16 @@ func getTable(file *extract.File) ([][]string, error) { return table, nil } startOCR := time.Now() - output, err := textract.Extract(file) - if err != nil { - return nil, fmt.Errorf("failed to extract: %w", err) - } + // Don't use Textract's Analyze Document, use OCR and custom algorithm instead + // output, err := textract.AnalyzeDocument(file) + // if err != nil { + // return nil, fmt.Errorf("failed to extract: %w", err) + // } + output, err := textract.DetectDocumentText(file) log.Printf("textract: %s", time.Since(startOCR).String()) - table, err := textract.ToTableFromDetectedTable(output) + startAlgorithm := time.Now() + table, err := textract.ToTableFromOCR(output) + log.Printf("ocr-to-table: %s", time.Since(startAlgorithm).String()) if err != nil { return nil, fmt.Errorf("failed to convert to table: %w", err) } diff --git a/dynamodb/dynamodb.go b/dynamodb/dynamodb.go index 244ba99..cb840b9 100644 --- a/dynamodb/dynamodb.go +++ b/dynamodb/dynamodb.go @@ -42,8 +42,10 @@ func PutTable(checksum string, table []byte) error { svc := dynamodb.New(sess) putInput := &dynamodb.PutItemInput{ Item: map[string]*dynamodb.AttributeValue{ - "Checksum": {S: &checksum}, - "JSONTable": {B: table}, + "Checksum": {S: &checksum}, + // Old: Used table detection directly, new uses custom algorithm + // "JSONTable": {B: table}, + "JSONTableCustomDetection": {B: table}, }, TableName: aws.String("Tables"), } @@ -59,7 +61,7 @@ func GetTable(checksum string) ([]byte, error) { return nil, fmt.Errorf("unable to create session: %w", err) } svc := dynamodb.New(sess) - projection := "JSONTable" + projection := "JSONTable,JSONTableCustomDetection" getInput := &dynamodb.GetItemInput{ Key: map[string]*dynamodb.AttributeValue{ "Checksum": {S: &checksum}, @@ -71,9 +73,15 @@ func GetTable(checksum string) ([]byte, error) { if err != nil { return nil, fmt.Errorf("get item: %w", err) } - table, ok := output.Item["JSONTable"] + var table *dynamodb.AttributeValue + var ok bool + table, ok = output.Item["JSONTable"] if !ok { - return nil, nil + table, ok = output.Item["JSONTableCustomDetection"] + if !ok { + return nil, nil + } + return table.B, nil } return table.B, nil } diff --git a/textract/textract.go b/textract/textract.go index ac4d5c8..58cd31f 100644 --- a/textract/textract.go +++ b/textract/textract.go @@ -1,6 +1,7 @@ package textract import ( + "encoding/json" "fmt" "math" "sort" @@ -10,31 +11,33 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/textract" "github.com/vegarsti/extract" + "github.com/vegarsti/extract/box" "github.com/vegarsti/extract/s3" ) -func Extract(file *extract.File) (*textract.AnalyzeDocumentOutput, error) { +func AnalyzeDocument(file *extract.File) (*textract.AnalyzeDocumentOutput, error) { sess, err := session.NewSession() if err != nil { return nil, fmt.Errorf("unable to create session: %w", err) } svc := textract.New(sess) tables := "TABLES" - input := &textract.AnalyzeDocumentInput{ - Document: &textract.Document{Bytes: file.Bytes}, - FeatureTypes: []*string{&tables}, - } if file.ContentType == extract.PDF { - return extractPDF(file) + return analyzePDF(file) } - output, err := svc.AnalyzeDocument(input) + output, err := svc.AnalyzeDocument( + &textract.AnalyzeDocumentInput{ + Document: &textract.Document{Bytes: file.Bytes}, + FeatureTypes: []*string{&tables}, + }, + ) if err != nil { return nil, err } return output, nil } -func extractPDF(file *extract.File) (*textract.AnalyzeDocumentOutput, error) { +func analyzePDF(file *extract.File) (*textract.AnalyzeDocumentOutput, error) { sess, err := session.NewSession() if err != nil { return nil, fmt.Errorf("unable to create session: %w", err) @@ -76,6 +79,46 @@ func extractPDF(file *extract.File) (*textract.AnalyzeDocumentOutput, error) { }, nil } +func ocrPDF(file *extract.File) (*textract.DetectDocumentTextOutput, error) { + sess, err := session.NewSession() + if err != nil { + return nil, fmt.Errorf("unable to create session: %w", err) + } + if err := s3.UploadPDF(file.Checksum, file.Bytes); err != nil { + return nil, fmt.Errorf("upload PDF: %w", err) + } + svc := textract.New(sess) + bucket := "results.extract-table.com" + name := file.Checksum + ".pdf" + startInput := &textract.StartDocumentTextDetectionInput{ + DocumentLocation: &textract.DocumentLocation{ + S3Object: &textract.S3Object{ + Bucket: &bucket, + Name: &name, + }, + }, + } + startOutput, err := svc.StartDocumentTextDetection(startInput) + if err != nil { + return nil, fmt.Errorf("start document analysis: %w", err) + } + getInput := &textract.GetDocumentTextDetectionInput{JobId: startOutput.JobId} + processing := true + var getOutput *textract.GetDocumentTextDetectionOutput + for processing { + time.Sleep(10 * time.Millisecond) + getOutput, err = svc.GetDocumentTextDetection(getInput) + if err != nil { + return nil, fmt.Errorf("get document analysis: %w", err) + } + processing = *getOutput.JobStatus == "IN_PROGRESS" + } + return &textract.DetectDocumentTextOutput{ + Blocks: getOutput.Blocks, + DocumentMetadata: getOutput.DocumentMetadata, + }, nil +} + func ToTableFromDetectedTable(output *textract.AnalyzeDocumentOutput) ([][]string, error) { blocks := make(map[string]*textract.Block) var tables []*textract.Block @@ -132,6 +175,26 @@ func ToTableFromDetectedTable(output *textract.AnalyzeDocumentOutput) ([][]strin return rows, nil } +func DetectDocumentText(file *extract.File) (*textract.DetectDocumentTextOutput, error) { + sess, err := session.NewSession() + if err != nil { + return nil, fmt.Errorf("unable to create session: %w", err) + } + svc := textract.New(sess) + if file.ContentType == extract.PDF { + return ocrPDF(file) + } + output, err := svc.DetectDocumentText( + &textract.DetectDocumentTextInput{ + Document: &textract.Document{Bytes: file.Bytes}, + }, + ) + if err != nil { + return nil, err + } + return output, nil +} + func textInCellBlock(blocks map[string]*textract.Block, cell *textract.Block) string { var words []string for _, r := range cell.Relationships { @@ -201,3 +264,54 @@ func toTable(rows [][]extract.Word, splitAt []float64, splitFunc func([]extract. } return table } + +func ToTableFromOCR(output *textract.DetectDocumentTextOutput) ([][]string, error) { + blocks := make(map[string]*textract.Block) + words := 0 + for _, block := range output.Blocks { + blocks[*block.Id] = block + if *block.BlockType != "WORD" { + words++ + } + } + rowMap := make(map[int]map[int]string) + for _, cell := range blocks { + if *cell.BlockType == "CELL" { + rowIndex := int(*cell.RowIndex) + colIndex := int(*cell.ColumnIndex) + if _, ok := rowMap[rowIndex]; !ok { + rowMap[rowIndex] = make(map[int]string) + } + rowMap[rowIndex][colIndex] = textInCellBlock(blocks, cell) + } + } + // Debug printing + // fmt.Printf("%+v", rowMap) + // fmt.Printf("%+v", blocks) + + boxes := make([]box.Box, words) + for _, cell := range blocks { + if *cell.BlockType != "WORD" { + continue + } + box := box.Box{ + XLeft: *cell.Geometry.BoundingBox.Left, + XRight: *cell.Geometry.BoundingBox.Left + *cell.Geometry.BoundingBox.Width, + YTop: *cell.Geometry.BoundingBox.Top, + YBottom: *cell.Geometry.BoundingBox.Top + *cell.Geometry.BoundingBox.Height, + Content: *cell.Text, + } + // Debug printing + // fmt.Printf("left: %+v\n", *cell.Geometry.BoundingBox.Left) + // fmt.Printf("top: %+v\n", *cell.Geometry.BoundingBox.Top) + // fmt.Printf("width: %+v\n", *cell.Geometry.BoundingBox.Width) + // fmt.Printf("height: %+v\n", *cell.Geometry.BoundingBox.Height) + // fmt.Printf("%+v\n", box) + boxes = append(boxes, box) + } + table := box.ToTable(boxes) + tableJSON, _ := json.MarshalIndent(table, "", " ") + fmt.Println(string(tableJSON)) + + return nil, nil +}