coze-studio/backend/infra/impl/document/nl2sql/builtin/nl2sql_test.go

140 lines
5.6 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright 2025 coze-dev Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package builtin
import (
"context"
"testing"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/stretchr/testify/assert"
"github.com/coze-dev/coze-studio/backend/infra/contract/document"
)
func TestNL2SQL(t *testing.T) {
ctx := context.Background()
t.Run("test table meta not provided", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, nil)
assert.Error(t, err)
assert.Equal(t, "", sql)
})
t.Run("test parse failed", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{"mock resp"}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
{
Name: "mock_table_1",
Comment: "hello",
Columns: []*document.Column{
{
ID: 121,
Name: "id",
Type: document.TableColumnTypeInteger,
Description: "test",
Nullable: false,
IsPrimary: true,
Sequence: 0,
},
{
ID: 123,
Name: "col_1",
Type: document.TableColumnTypeString,
Description: "column_1",
Nullable: true,
IsPrimary: false,
Sequence: 1,
},
},
},
})
assert.Error(t, err)
assert.Equal(t, "", sql)
})
t.Run("test success", func(t *testing.T) {
impl, err := NewNL2SQL(ctx, &mockChatModel{`{"sql":"mock sql","err_code":0,"err_msg":""}`}, prompt.FromMessages(schema.Jinja2,
schema.SystemMessage("system message 123"),
schema.UserMessage("{{messages}}, {{table_meta}}"),
))
assert.NoError(t, err)
sql, err := impl.NL2SQL(ctx, []*schema.Message{schema.UserMessage("hello")}, []*document.TableSchema{
{
Name: "mock_table_1",
Comment: "hello",
Columns: []*document.Column{
{
ID: 121,
Name: "id",
Type: document.TableColumnTypeInteger,
Description: "test",
Nullable: false,
IsPrimary: true,
Sequence: 0,
},
{
ID: 123,
Name: "col_1",
Type: document.TableColumnTypeString,
Description: "column_1",
Nullable: true,
IsPrimary: false,
Sequence: 1,
},
},
},
})
assert.NoError(t, err)
assert.Equal(t, "mock sql", sql)
})
}
// mockChatModel is a canned chat model used by the tests in this file.
// Its Generate method answers every request with content.
type mockChatModel struct {
	// content is the text returned as the assistant reply.
	content string
}
// Generate ignores ctx, input and opts and always replies with an assistant
// message whose text is m.content. It never returns an error.
func (m mockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
	reply := schema.AssistantMessage(m.content, nil)
	return reply, nil
}
// Stream is a stub: it ignores its arguments and returns a nil reader together
// with a nil error, so callers must not dereference the result.
func (m mockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
	var reader *schema.StreamReader[*schema.Message] // intentionally left nil
	return reader, nil
}
// BindTools is a no-op: the supplied tools are discarded and nil is returned.
func (m mockChatModel) BindTools(tools []*schema.ToolInfo) error {
	var err error // always nil; the mock keeps no tool state
	return err
}
// sys is a system prompt that instructs a model to translate natural language
// into MySQL and to answer with a bare JSON object holding the fields "sql",
// "err_code" and "err_msg". It lists the err_code values (0, 3002, 3003, 3005)
// and ends with one worked question/answer example.
//
// The text is written as one string per prompt line so each line can be
// reviewed on its own; the pieces are joined at compile time. The literal
// deliberately contains no invisible (zero-width) characters.
const sys = "# Role: NL2SQL Consultant\n\n" +
	"## Goals\n" +
	"Translate natural language statements into SQL queries in MySQL standard. Follow the Constraints and return only a JSON always.\n\n" +
	"## Format\n" +
	"- JSON format only. JSON contains field \"sql\" for generated SQL, field \"err_code\" for reason type, field \"err_msg\" for detail reason (prefer more than 10 words)\n" +
	"- Don't use \"```json\" markdown format\n\n" +
	"## Skills\n" +
	"- Good at Translate natural language statements into SQL queries in MySQL standard.\n\n" +
	"## Define\n" +
	"\"err_code\" Reason Type Define:\n" +
	"- 0 means you generated a SQL\n" +
	"- 3002 means you cannot generate a SQL because of timeout\n" +
	"- 3003 means you cannot generate a SQL because of table schema missing\n" +
	"- 3005 means you cannot generate a SQL because of some term is ambiguous\n\n" +
	"## Example\n" +
	"Q: Help me implement NL2SQL.\n" +
	".table schema description: CREATE TABLE `sales_records` (\\n `sales_id` bigint(20) unsigned NOT NULL COMMENT 'id of sales person',\\n `product_id` bigint(64) COMMENT 'id of product',\\n `sale_date` datetime(3) COMMENT 'sold date and time',\\n `quantity_sold` int(11) COMMENT 'sold amount',\\n PRIMARY KEY (`sales_id`)\\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='销售记录表';\n" +
	".natural language description of the SQL requirement: 查询上月的销量总额第一名的销售员和他的销售总额\n" +
	"A: {\n" +
	" \"sql\":\"SELECT sales_id, SUM(quantity_sold) AS total_sales FROM sales_records WHERE MONTH(sale_date) = MONTH(CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR(sale_date) = YEAR(CURRENT_DATE - INTERVAL 1 MONTH) GROUP BY sales_id ORDER BY total_sales DESC LIMIT 1\",\n" +
	" \"err_code\":0,\n" +
	" \"err_msg\":\"SQL query generated successfully\"\n" +
	"}"
// usr is a user-message template with two placeholders: {{tableSchema}} for
// the table schema description and {{chat_history}} for the natural language
// request. It is written one prompt line per string and joined at compile time.
const usr = "help me implement NL2SQL.\n" +
	"table schema description:{{tableSchema}}\n" +
	"natural language description of the SQL requirement: {{chat_history}}."