/*
 * Copyright 2025 coze-dev Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | 
						||
 | 
						||
package builtin
 | 
						||
 | 
						||
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"

	"github.com/cloudwego/eino/components/document/parser"
	"github.com/cloudwego/eino/schema"

	"github.com/coze-dev/coze-studio/backend/infra/contract/document"
	"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
	contract "github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
	"github.com/coze-dev/coze-studio/backend/infra/contract/storage"
)
 | 
						||
 | 
						||
// Values of the "type" field carried by each content element that the Python
// parser script returns (see pyParseContent.Type). parseByPython dispatches on
// them and rejects any other value.
const (
	contentTypeText  = "text"  // Content holds plain text to be chunked
	contentTypeImage = "image" // Content holds a base64-encoded image
	contentTypeTable = "table" // Table holds the rows of a table
)
 | 
						||
 | 
						||
// pyParseRequest is the JSON request handed to the Python parser script.
// parseByPython fills it from the ParsingStrategy of the parser config:
// ExtractImages from ExtractImage, ExtractTables from ExtractTable and
// FilterPages from FilterPages.
type pyParseRequest struct {
	ExtractImages bool  `json:"extract_images"`
	ExtractTables bool  `json:"extract_tables"`
	FilterPages   []int `json:"filter_pages"`
}
 | 
						||
 | 
						||
type pyParseResult struct {
 | 
						||
	Error   string            `json:"error"`
 | 
						||
	Content []*pyParseContent `json:"content"`
 | 
						||
}
 | 
						||
 | 
						||
// pyParseContent is one element of a pyParseResult.
//
// Type selects how the element is consumed by parseByPython: for "text"
// Content is the text to chunk, for "image" Content is the base64-encoded
// image, and for "table" Table holds the rows. Page is decoded from the
// "page" field of the script output.
type pyParseContent struct {
	Type    string     `json:"type"`
	Content string     `json:"content"`
	Table   [][]string `json:"table"`
	Page    int        `json:"page"`
}
 | 
						||
 | 
						||
// pyPDFTableIterator hands out the rows of a single table returned by the
// Python parser, one row per NextRow call, in their original order.
type pyPDFTableIterator struct {
	i    int        // index of the next row to return
	rows [][]string // all rows of the table
}

// NextRow returns the next unread row. Once every row has been consumed it
// returns a nil row with end set to true. The returned error is always nil.
func (p *pyPDFTableIterator) NextRow() (row []string, end bool, err error) {
	if p.i >= len(p.rows) {
		return nil, true, nil
	}
	row = p.rows[p.i]
	p.i++
	return row, false, nil
}
 | 
						||
 | 
						||
func parseByPython(config *contract.Config, storage storage.Storage, ocr ocr.OCR, pyPath, scriptPath string) parseFn {
 | 
						||
	return func(ctx context.Context, reader io.Reader, opts ...parser.Option) (docs []*schema.Document, err error) {
 | 
						||
		pr, pw, err := os.Pipe()
 | 
						||
		if err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] create rpipe failed, %w", err)
 | 
						||
		}
 | 
						||
		r, w, err := os.Pipe()
 | 
						||
		if err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] create pipe failed: %w", err)
 | 
						||
		}
 | 
						||
		options := parser.GetCommonOptions(&parser.Options{ExtraMeta: map[string]any{}}, opts...)
 | 
						||
 | 
						||
		reqb, err := json.Marshal(pyParseRequest{
 | 
						||
			ExtractImages: config.ParsingStrategy.ExtractImage,
 | 
						||
			ExtractTables: config.ParsingStrategy.ExtractTable,
 | 
						||
			FilterPages:   config.ParsingStrategy.FilterPages,
 | 
						||
		})
 | 
						||
		if err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] create parse request failed, %w", err)
 | 
						||
		}
 | 
						||
		if _, err = pw.Write(reqb); err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] write parse request bytes failed, %w", err)
 | 
						||
		}
 | 
						||
		if err = pw.Close(); err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] close write request pipe failed, %w", err)
 | 
						||
		}
 | 
						||
 | 
						||
		cmd := exec.Command(pyPath, scriptPath)
 | 
						||
		cmd.Stdin = reader
 | 
						||
		cmd.Stdout = os.Stdout
 | 
						||
		cmd.ExtraFiles = []*os.File{w, pr}
 | 
						||
		if err = cmd.Start(); err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] failed to start Python script: %w", err)
 | 
						||
		}
 | 
						||
		if err = w.Close(); err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] failed to close write pipe: %w", err)
 | 
						||
		}
 | 
						||
 | 
						||
		result := &pyParseResult{}
 | 
						||
 | 
						||
		if err = json.NewDecoder(r).Decode(result); err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] failed to decode result: %w", err)
 | 
						||
		}
 | 
						||
		if err = cmd.Wait(); err != nil {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] cmd wait err: %w", err)
 | 
						||
		}
 | 
						||
 | 
						||
		if result.Error != "" {
 | 
						||
			return nil, fmt.Errorf("[parseByPython] python execution failed: %s", result.Error)
 | 
						||
		}
 | 
						||
 | 
						||
		for i, item := range result.Content {
 | 
						||
			switch item.Type {
 | 
						||
			case contentTypeText:
 | 
						||
				partDocs, err := chunkCustom(ctx, item.Content, config, opts...)
 | 
						||
				if err != nil {
 | 
						||
					return nil, fmt.Errorf("[parseByPython] chunk text failed, %w", err)
 | 
						||
				}
 | 
						||
				docs = append(docs, partDocs...)
 | 
						||
			case contentTypeImage:
 | 
						||
				if !config.ParsingStrategy.ExtractImage {
 | 
						||
					continue
 | 
						||
				}
 | 
						||
				image, err := base64.StdEncoding.DecodeString(item.Content)
 | 
						||
				if err != nil {
 | 
						||
					return nil, fmt.Errorf("[parseByPython] decode image failed, %w", err)
 | 
						||
				}
 | 
						||
				imgSrc, err := putImageObject(ctx, storage, "png", getCreatorIDFromExtraMeta(options.ExtraMeta), image)
 | 
						||
				if err != nil {
 | 
						||
					return nil, err
 | 
						||
				}
 | 
						||
				label := fmt.Sprintf("\n%s", imgSrc)
 | 
						||
				if config.ParsingStrategy.ImageOCR && ocr != nil {
 | 
						||
					texts, err := ocr.FromBase64(ctx, item.Content)
 | 
						||
					if err != nil {
 | 
						||
						return nil, fmt.Errorf("[parseByPython] FromBase64 failed, %w", err)
 | 
						||
					}
 | 
						||
					label += strings.Join(texts, "\n")
 | 
						||
				}
 | 
						||
 | 
						||
				if i == len(result.Content)-1 || result.Content[i+1].Type != "text" {
 | 
						||
					doc := &schema.Document{
 | 
						||
						Content:  label,
 | 
						||
						MetaData: map[string]any{},
 | 
						||
					}
 | 
						||
					for k, v := range options.ExtraMeta {
 | 
						||
						doc.MetaData[k] = v
 | 
						||
					}
 | 
						||
					docs = append(docs, doc)
 | 
						||
				} else {
 | 
						||
					// TODO: 这里有点问题,img label 可能被较短的 chunk size 截断
 | 
						||
					result.Content[i+1].Content = label + result.Content[i+1].Content
 | 
						||
				}
 | 
						||
			case contentTypeTable:
 | 
						||
				if !config.ParsingStrategy.ExtractTable {
 | 
						||
					continue
 | 
						||
				}
 | 
						||
				iterator := &pyPDFTableIterator{i: 0, rows: item.Table}
 | 
						||
				rawTableDocs, err := parseByRowIterator(iterator, &contract.Config{
 | 
						||
					FileExtension: contract.FileExtensionCSV,
 | 
						||
					ParsingStrategy: &contract.ParsingStrategy{
 | 
						||
						HeaderLine:    0,
 | 
						||
						DataStartLine: 1,
 | 
						||
						RowsCount:     0,
 | 
						||
					},
 | 
						||
					ChunkingStrategy: config.ChunkingStrategy,
 | 
						||
				}, opts...)
 | 
						||
				if err != nil {
 | 
						||
					return nil, fmt.Errorf("[parseByPython] parse table failed, %w", err)
 | 
						||
				}
 | 
						||
				fmtTableDocs, err := formatTablesInDocument(rawTableDocs)
 | 
						||
				if err != nil {
 | 
						||
					return nil, fmt.Errorf("[parseByPython] format table failed, %w", err)
 | 
						||
				}
 | 
						||
				docs = append(docs, fmtTableDocs...)
 | 
						||
			default:
 | 
						||
				return nil, fmt.Errorf("[parseByPython] invalid content type: %s", item.Type)
 | 
						||
			}
 | 
						||
		}
 | 
						||
 | 
						||
		return docs, nil
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func formatTablesInDocument(input []*schema.Document) (output []*schema.Document, err error) {
 | 
						||
	const (
 | 
						||
		maxSize              = 65535
 | 
						||
		tableStart, tableEnd = "<table>", "</table>"
 | 
						||
	)
 | 
						||
 | 
						||
	var (
 | 
						||
		buffer   strings.Builder
 | 
						||
		firstDoc *schema.Document
 | 
						||
	)
 | 
						||
 | 
						||
	endSize := len(tableEnd)
 | 
						||
	buffer.WriteString(tableStart)
 | 
						||
 | 
						||
	push := func() {
 | 
						||
		newDoc := &schema.Document{
 | 
						||
			Content:  buffer.String() + tableEnd,
 | 
						||
			MetaData: map[string]any{},
 | 
						||
		}
 | 
						||
		for k, v := range firstDoc.MetaData {
 | 
						||
			if k == document.MetaDataKeyColumnData {
 | 
						||
				continue
 | 
						||
			}
 | 
						||
			newDoc.MetaData[k] = v
 | 
						||
		}
 | 
						||
		output = append(output, newDoc)
 | 
						||
		buffer.Reset()
 | 
						||
		buffer.WriteString(tableStart)
 | 
						||
	}
 | 
						||
 | 
						||
	write := func(contents []string) {
 | 
						||
		row := fmt.Sprintf("<tr><td>%s</td></tr>", strings.Join(contents, "</td><td>"))
 | 
						||
		buffer.WriteString(row)
 | 
						||
		if buffer.Len()+endSize >= maxSize {
 | 
						||
			push()
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	for i := range input {
 | 
						||
		doc := input[i]
 | 
						||
 | 
						||
		if i == 0 {
 | 
						||
			firstDoc = doc
 | 
						||
			cols, err := document.GetDocumentColumns(doc)
 | 
						||
			if err != nil {
 | 
						||
				return nil, fmt.Errorf("[formatTablesInDocument] invalid table columns, %w", err)
 | 
						||
			}
 | 
						||
			values := make([]string, 0, len(cols))
 | 
						||
			for _, col := range cols {
 | 
						||
				values = append(values, col.Name)
 | 
						||
			}
 | 
						||
			write(values)
 | 
						||
			if colOnly, err := document.GetDocumentColumnsOnly(doc); err != nil {
 | 
						||
				return nil, err
 | 
						||
			} else if colOnly {
 | 
						||
				break
 | 
						||
			}
 | 
						||
		}
 | 
						||
 | 
						||
		data, err := document.GetDocumentColumnData(doc)
 | 
						||
		if err != nil {
 | 
						||
			return nil, fmt.Errorf("[formatTablesInDocument] invalid table data, %w", err)
 | 
						||
		}
 | 
						||
		values := make([]string, 0, len(data))
 | 
						||
		for _, col := range data {
 | 
						||
			values = append(values, col.GetNullableStringValue())
 | 
						||
		}
 | 
						||
		write(values)
 | 
						||
	}
 | 
						||
 | 
						||
	if buffer.String() != tableStart {
 | 
						||
		push()
 | 
						||
	}
 | 
						||
 | 
						||
	return
 | 
						||
}
 |