diff --git a/backend/api/handler/coze/workflow_service_test.go b/backend/api/handler/coze/workflow_service_test.go index e71fff9f..fd2f0ceb 100644 --- a/backend/api/handler/coze/workflow_service_test.go +++ b/backend/api/handler/coze/workflow_service_test.go @@ -41,7 +41,6 @@ import ( "github.com/cloudwego/hertz/pkg/common/ut" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/sse" - "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -86,6 +85,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/service" "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" + "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" "github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint" "github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct" mockCrossUser "github.com/coze-dev/coze-studio/backend/internal/mock/crossdomain/crossuser" @@ -239,9 +239,7 @@ func newWfTestRunner(t *testing.T) *wfTestRunner { t.Fatalf("Failed to start miniredis: %v", err) } - redisClient := redis.NewClient(&redis.Options{ - Addr: s.Addr(), - }) + redisClient := redis.NewWithAddrAndPassword(s.Addr(), "") cpStore := checkpoint.NewRedisStore(redisClient) diff --git a/backend/application/base/appinfra/app_infra.go b/backend/application/base/appinfra/app_infra.go index 2aabe24a..23409fb1 100644 --- a/backend/application/base/appinfra/app_infra.go +++ b/backend/application/base/appinfra/app_infra.go @@ -25,6 +25,7 @@ import ( "gorm.io/gorm" + "github.com/coze-dev/coze-studio/backend/infra/contract/cache" "github.com/coze-dev/coze-studio/backend/infra/contract/coderunner" "github.com/coze-dev/coze-studio/backend/infra/contract/imagex" "github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr" @@ -42,7 +43,7 @@ import ( type AppDependencies struct { DB *gorm.DB - CacheCli *redis.Client + CacheCli cache.Cmdable IDGenSVC idgen.IDGenerator ESClient es.Client ImageXClient imagex.ImageX diff --git a/backend/application/search/init.go b/backend/application/search/init.go index f11cc303..430e0c30 100644 --- a/backend/application/search/init.go +++ b/backend/application/search/init.go @@ -33,9 +33,9 @@ import ( search "github.com/coze-dev/coze-studio/backend/domain/search/service" user "github.com/coze-dev/coze-studio/backend/domain/user/service" "github.com/coze-dev/coze-studio/backend/domain/workflow" + "github.com/coze-dev/coze-studio/backend/infra/contract/cache" "github.com/coze-dev/coze-studio/backend/infra/contract/es" "github.com/coze-dev/coze-studio/backend/infra/contract/storage" - "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus" "github.com/coze-dev/coze-studio/backend/pkg/logs" "github.com/coze-dev/coze-studio/backend/types/consts" @@ -43,7 +43,7 @@ import ( type ServiceComponents struct { DB *gorm.DB - Cache *redis.Client + Cache cache.Cmdable TOS storage.Storage ESClient es.Client ProjectEventBus ProjectEventBus diff --git a/backend/crossdomain/workflow/plugin/plugin_test.go b/backend/crossdomain/workflow/plugin/plugin_test.go index 66c93e33..eaf632b7 100644 --- a/backend/crossdomain/workflow/plugin/plugin_test.go +++ b/backend/crossdomain/workflow/plugin/plugin_test.go @@ -19,8 +19,9 @@ package plugin import ( "testing" - "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" "github.com/stretchr/testify/assert" + + "github.com/coze-dev/coze-studio/backend/domain/workflow/entity/vo" ) func TestPluginService_UnwrapArrayItemFieldsInVariable(t *testing.T) { @@ -214,4 +215,4 @@ func TestPluginService_UnwrapArrayItemFieldsInVariable(t *testing.T) { err := s.UnwrapArrayItemFieldsInVariable(nil) assert.NoError(t, err) }) -} \ No newline at end of file +} diff --git a/backend/domain/memory/database/service/database_impl.go b/backend/domain/memory/database/service/database_impl.go index d2dbdb37..651d78f0 100644 --- a/backend/domain/memory/database/service/database_impl.go +++ b/backend/domain/memory/database/service/database_impl.go @@ -28,7 +28,6 @@ import ( "strconv" "time" - "github.com/redis/go-redis/v9" "github.com/tealeg/xlsx/v3" "gorm.io/gorm" @@ -1921,22 +1920,22 @@ func (d databaseService) GetDatabaseFileProgressData(ctx context.Context, req *G currentFileName = draftCurrentFileName } totalNum, err := d.cache.Get(ctx, fmt.Sprintf(totalKey, req.DatabaseID, req.UserID)).Int64() - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil && !errors.Is(err, cache.Nil) { return nil, err } progressNum, err := d.cache.Get(ctx, fmt.Sprintf(progressKey, req.DatabaseID, req.UserID)).Int64() - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil && !errors.Is(err, cache.Nil) { return nil, err } failReason, err := d.cache.Get(ctx, fmt.Sprintf(failKey, req.DatabaseID, req.UserID)).Result() - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil && !errors.Is(err, cache.Nil) { return nil, err } fileName, err := d.cache.Get(ctx, fmt.Sprintf(currentFileName, req.DatabaseID, req.UserID)).Result() - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil && !errors.Is(err, cache.Nil) { return nil, err } diff --git a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go index 93c05b36..59ba16e1 100644 --- a/backend/domain/workflow/internal/canvas/adaptor/to_schema.go +++ b/backend/domain/workflow/internal/canvas/adaptor/to_schema.go @@ -438,7 +438,7 @@ func PruneIsolatedNodes(nodes []*vo.Node, edges []*vo.Edge, parentNode *vo.Node) func parseBatchMode(n *vo.Node) ( batchN *vo.Node, // the new batch node - enabled bool, // whether the node has enabled batch mode + enabled bool, // whether the node has enabled batch mode err error) { if n.Data == nil || n.Data.Inputs == nil { return nil, false, nil diff --git a/backend/domain/workflow/internal/compose/test/question_answer_test.go b/backend/domain/workflow/internal/compose/test/question_answer_test.go index 4ad1b5e6..549a7071 100644 --- a/backend/domain/workflow/internal/compose/test/question_answer_test.go +++ b/backend/domain/workflow/internal/compose/test/question_answer_test.go @@ -31,7 +31,6 @@ import ( model2 "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/schema" - "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "gorm.io/driver/mysql" @@ -47,6 +46,7 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/exit" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/nodes/qa" repo2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo" + "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" schema2 "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/schema" "github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint" mock "github.com/coze-dev/coze-studio/backend/internal/mock/infra/contract/idgen" @@ -96,9 +96,7 @@ func TestQuestionAnswer(t *testing.T) { } defer s.Close() - redisClient := redis.NewClient(&redis.Options{ - Addr: s.Addr(), - }) + redisClient := redis.NewWithAddrAndPassword(s.Addr(), "") mockIDGen := mock.NewMockIDGenerator(ctrl) mockIDGen.EXPECT().GenID(gomock.Any()).Return(time.Now().UnixNano(), nil).AnyTimes() diff --git a/backend/domain/workflow/internal/repo/execute_history_store_test.go b/backend/domain/workflow/internal/repo/execute_history_store_test.go index 37a7f704..5987da5c 100644 --- a/backend/domain/workflow/internal/repo/execute_history_store_test.go +++ b/backend/domain/workflow/internal/repo/execute_history_store_test.go @@ -25,7 +25,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "gorm.io/driver/mysql" @@ -33,12 +32,14 @@ import ( "github.com/coze-dev/coze-studio/backend/domain/workflow/entity" "github.com/coze-dev/coze-studio/backend/domain/workflow/internal/repo/dal/query" + "github.com/coze-dev/coze-studio/backend/infra/contract/cache" + "github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis" ) type ExecuteHistoryStoreSuite struct { suite.Suite db *gorm.DB - redis *redis.Client + redis cache.Cmdable mock sqlmock.Sqlmock store *executeHistoryStoreImpl } @@ -47,7 +48,7 @@ func (s *ExecuteHistoryStoreSuite) SetupTest() { var err error mr, err := miniredis.Run() assert.NoError(s.T(), err) - s.redis = redis.NewClient(&redis.Options{Addr: mr.Addr()}) + s.redis = redis.NewWithAddrAndPassword(mr.Addr(), "") mockDB, mock, err := sqlmock.New() assert.NoError(s.T(), err) diff --git a/backend/infra/contract/cache/cache.go b/backend/infra/contract/cache/cache.go index 0555ea04..c1f0181c 100644 --- a/backend/infra/contract/cache/cache.go +++ b/backend/infra/contract/cache/cache.go @@ -17,9 +17,96 @@ package cache import ( - "github.com/redis/go-redis/v9" + "context" + "time" ) -type Cmdable = redis.Cmdable +var Nil error -const Nil = redis.Nil +func SetDefaultNilError(err error) { + Nil = err +} + +type Cmdable interface { + Pipeline() Pipeliner + StringCmdable + HashCmdable + GenericCmdable + ListCmdable +} + +type StringCmdable interface { + Set(ctx context.Context, key string, value interface{}, expiration time.Duration) StatusCmd + Get(ctx context.Context, key string) StringCmd + IncrBy(ctx context.Context, key string, value int64) IntCmd + Incr(ctx context.Context, key string) IntCmd +} + +type HashCmdable interface { + HSet(ctx context.Context, key string, values ...interface{}) IntCmd + HGetAll(ctx context.Context, key string) MapStringStringCmd +} + +type GenericCmdable interface { + Del(ctx context.Context, keys ...string) IntCmd + Exists(ctx context.Context, keys ...string) IntCmd + Expire(ctx context.Context, key string, expiration time.Duration) BoolCmd +} + +type Pipeliner interface { + StatefulCmdable + Exec(ctx context.Context) ([]Cmder, error) +} + +type StatefulCmdable interface { + Cmdable +} + +type ListCmdable interface { + LIndex(ctx context.Context, key string, index int64) StringCmd + LPush(ctx context.Context, key string, values ...interface{}) IntCmd + RPush(ctx context.Context, key string, values ...interface{}) IntCmd + LSet(ctx context.Context, key string, index int64, value interface{}) StatusCmd + LPop(ctx context.Context, key string) StringCmd + LRange(ctx context.Context, key string, start, stop int64) StringSliceCmd +} +type Cmder interface { + Err() error +} + +type baseCmd interface { + Err() error +} + +type IntCmd interface { + baseCmd + Result() (int64, error) +} + +type MapStringStringCmd interface { + baseCmd + Result() (map[string]string, error) +} + +type BoolCmd interface { + baseCmd + Result() (bool, error) +} + +type StatusCmd interface { + baseCmd + Result() (string, error) +} + +type StringCmd interface { + baseCmd + Result() (string, error) + Val() string + Int64() (int64, error) + Bytes() ([]byte, error) +} + +type StringSliceCmd interface { + baseCmd + Result() ([]string, error) +} diff --git a/backend/infra/impl/cache/redis/redis.go b/backend/infra/impl/cache/redis/redis.go index 5c2ed55c..4e6de603 100644 --- a/backend/infra/impl/cache/redis/redis.go +++ b/backend/infra/impl/cache/redis/redis.go @@ -17,17 +17,24 @@ package redis import ( + "context" "os" "time" "github.com/redis/go-redis/v9" + + "github.com/coze-dev/coze-studio/backend/infra/contract/cache" ) -type Client = redis.Client - -func New() *redis.Client { +func New() cache.Cmdable { addr := os.Getenv("REDIS_ADDR") password := os.Getenv("REDIS_PASSWORD") + cache.SetDefaultNilError(redis.Nil) + + return NewWithAddrAndPassword(addr, password) +} + +func NewWithAddrAndPassword(addr, password string) cache.Cmdable { rdb := redis.NewClient(&redis.Options{ Addr: addr, // Redis地址 DB: 0, // 默认数据库 @@ -44,5 +51,199 @@ func New() *redis.Client { WriteTimeout: 3 * time.Second, // write operation timed out }) - return rdb + return &redisImpl{client: rdb} +} + +type redisImpl struct { + client *redis.Client +} + +// Del implements cache.Cmdable. +func (r *redisImpl) Del(ctx context.Context, keys ...string) cache.IntCmd { + return r.client.Del(ctx, keys...) +} + +// Exists implements cache.Cmdable. +func (r *redisImpl) Exists(ctx context.Context, keys ...string) cache.IntCmd { + return r.client.Exists(ctx, keys...) +} + +// Expire implements cache.Cmdable. +func (r *redisImpl) Expire(ctx context.Context, key string, expiration time.Duration) cache.BoolCmd { + return r.client.Expire(ctx, key, expiration) +} + +// Get implements cache.Cmdable. +func (r *redisImpl) Get(ctx context.Context, key string) cache.StringCmd { + return r.client.Get(ctx, key) +} + +// HGetAll implements cache.Cmdable. +func (r *redisImpl) HGetAll(ctx context.Context, key string) cache.MapStringStringCmd { + return r.client.HGetAll(ctx, key) +} + +// HSet implements cache.Cmdable. +func (r *redisImpl) HSet(ctx context.Context, key string, values ...interface{}) cache.IntCmd { + return r.client.HSet(ctx, key, values...) +} + +// Incr implements cache.Cmdable. +func (r *redisImpl) Incr(ctx context.Context, key string) cache.IntCmd { + return r.client.Incr(ctx, key) +} + +// IncrBy implements cache.Cmdable. +func (r *redisImpl) IncrBy(ctx context.Context, key string, value int64) cache.IntCmd { + return r.client.IncrBy(ctx, key, value) +} + +// LIndex implements cache.Cmdable. +func (r *redisImpl) LIndex(ctx context.Context, key string, index int64) cache.StringCmd { + return r.client.LIndex(ctx, key, index) +} + +// LPop implements cache.Cmdable. +func (r *redisImpl) LPop(ctx context.Context, key string) cache.StringCmd { + return r.client.LPop(ctx, key) +} + +// LPush implements cache.Cmdable. +func (r *redisImpl) LPush(ctx context.Context, key string, values ...interface{}) cache.IntCmd { + return r.client.LPush(ctx, key, values...) +} + +// LRange implements cache.Cmdable. +func (r *redisImpl) LRange(ctx context.Context, key string, start int64, stop int64) cache.StringSliceCmd { + return r.client.LRange(ctx, key, start, stop) +} + +// LSet implements cache.Cmdable. +func (r *redisImpl) LSet(ctx context.Context, key string, index int64, value interface{}) cache.StatusCmd { + return r.client.LSet(ctx, key, index, value) +} + +// Pipeline implements cache.Cmdable. +func (r *redisImpl) Pipeline() cache.Pipeliner { + p := r.client.Pipeline() + return &pipelineImpl{p: p} +} + +// RPush implements cache.Cmdable. +func (r *redisImpl) RPush(ctx context.Context, key string, values ...interface{}) cache.IntCmd { + return r.client.RPush(ctx, key, values...) +} + +// Set implements cache.Cmdable. +func (r *redisImpl) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) cache.StatusCmd { + return r.client.Set(ctx, key, value, expiration) +} + +type pipelineImpl struct { + p redis.Pipeliner +} + +// Del implements cache.Pipeliner. +func (p *pipelineImpl) Del(ctx context.Context, keys ...string) cache.IntCmd { + return p.p.Del(ctx, keys...) +} + +// Exec implements cache.Pipeliner. +func (p *pipelineImpl) Exec(ctx context.Context) ([]cache.Cmder, error) { + cmders, err := p.p.Exec(ctx) + if err != nil { + return nil, err + } + return convertCmders(cmders), nil +} + +func convertCmders(cmders []redis.Cmder) []cache.Cmder { + res := make([]cache.Cmder, 0, len(cmders)) + for _, cmder := range cmders { + res = append(res, &cmderImpl{cmder: cmder}) + } + return res +} + +type cmderImpl struct { + cmder redis.Cmder +} + +func (c *cmderImpl) Err() error { + return c.cmder.Err() +} + +// Exists implements cache.Pipeliner. +func (p *pipelineImpl) Exists(ctx context.Context, keys ...string) cache.IntCmd { + return p.p.Exists(ctx, keys...) +} + +// Expire implements cache.Pipeliner. +func (p *pipelineImpl) Expire(ctx context.Context, key string, expiration time.Duration) cache.BoolCmd { + return p.p.Expire(ctx, key, expiration) +} + +// Get implements cache.Pipeliner. +func (p *pipelineImpl) Get(ctx context.Context, key string) cache.StringCmd { + return p.p.Get(ctx, key) +} + +// HGetAll implements cache.Pipeliner. +func (p *pipelineImpl) HGetAll(ctx context.Context, key string) cache.MapStringStringCmd { + return p.p.HGetAll(ctx, key) +} + +// HSet implements cache.Pipeliner. +func (p *pipelineImpl) HSet(ctx context.Context, key string, values ...interface{}) cache.IntCmd { + return p.p.HSet(ctx, key, values...) +} + +// Incr implements cache.Pipeliner. +func (p *pipelineImpl) Incr(ctx context.Context, key string) cache.IntCmd { + return p.p.Incr(ctx, key) +} + +// IncrBy implements cache.Pipeliner. +func (p *pipelineImpl) IncrBy(ctx context.Context, key string, value int64) cache.IntCmd { + return p.p.IncrBy(ctx, key, value) +} + +// LIndex implements cache.Pipeliner. +func (p *pipelineImpl) LIndex(ctx context.Context, key string, index int64) cache.StringCmd { + return p.p.LIndex(ctx, key, index) +} + +// LPop implements cache.Pipeliner. +func (p *pipelineImpl) LPop(ctx context.Context, key string) cache.StringCmd { + return p.p.LPop(ctx, key) +} + +// LPush implements cache.Pipeliner. +func (p *pipelineImpl) LPush(ctx context.Context, key string, values ...interface{}) cache.IntCmd { + return p.p.LPush(ctx, key, values...) +} + +// LRange implements cache.Pipeliner. +func (p *pipelineImpl) LRange(ctx context.Context, key string, start int64, stop int64) cache.StringSliceCmd { + return p.p.LRange(ctx, key, start, stop) +} + +// LSet implements cache.Pipeliner. +func (p *pipelineImpl) LSet(ctx context.Context, key string, index int64, value interface{}) cache.StatusCmd { + return p.p.LSet(ctx, key, index, value) +} + +// Pipeline implements cache.Pipeliner. +func (p *pipelineImpl) Pipeline() cache.Pipeliner { + return p +} + +// RPush implements cache.Pipeliner. +func (p *pipelineImpl) RPush(ctx context.Context, key string, values ...interface{}) cache.IntCmd { + return p.p.RPush(ctx, key, values...) +} + +// Set implements cache.Pipeliner. +func (p *pipelineImpl) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) cache.StatusCmd { + return p.p.Set(ctx, key, value, expiration) } diff --git a/backend/infra/impl/checkpoint/redis.go b/backend/infra/impl/checkpoint/redis.go index 059bc907..7724a59e 100644 --- a/backend/infra/impl/checkpoint/redis.go +++ b/backend/infra/impl/checkpoint/redis.go @@ -23,7 +23,6 @@ import ( "time" "github.com/cloudwego/eino/compose" - "github.com/redis/go-redis/v9" "github.com/coze-dev/coze-studio/backend/infra/contract/cache" ) @@ -52,6 +51,6 @@ func (r *redisStore) Set(ctx context.Context, checkPointID string, checkPoint [] return r.client.Set(ctx, fmt.Sprintf(checkpointKeyTpl, checkPointID), checkPoint, checkpointExpire).Err() } -func NewRedisStore(client *redis.Client) compose.CheckPointStore { +func NewRedisStore(client cache.Cmdable) compose.CheckPointStore { return &redisStore{client: client} }