master
parent
6a95f03915
commit
d292f0c23d
51 changed files with 173390 additions and 6 deletions
@ -0,0 +1,70 @@ |
||||
package main |
||||
|
||||
import ( |
||||
_ "embed" |
||||
"flag" |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mysql/ddl" |
||||
"git.diulo.com/mogfee/kit/stringx" |
||||
"html/template" |
||||
"os" |
||||
"os/exec" |
||||
"strings" |
||||
) |
||||
|
||||
//go:embed template.tpl
|
||||
var modelTemplate string |
||||
|
||||
func main() { |
||||
sqlFile := flag.String("f", "", "数据库创建mysqldump文件") |
||||
savePath := flag.String("s", "./model/", "数据库存储路径") |
||||
flag.Parse() |
||||
fmt.Println("sqlfile:", *sqlFile) |
||||
fmt.Println("savePath:", *savePath) |
||||
if *sqlFile == "" { |
||||
return |
||||
} |
||||
os.MkdirAll(*savePath, os.ModePerm) |
||||
err := ddl.Parser(*sqlFile, func(table *ddl.Table) error { |
||||
table.Imports = append(table.Imports, "gorm.io/gorm") |
||||
table.Imports = append(table.Imports, "context") |
||||
//table.Imports = append(table.Imports, "git.diulo.com/mogfee/kit/errors")
|
||||
saveFileName := fmt.Sprintf("%s/%s_gen.go", strings.TrimRight(*savePath, "/"), table.Name) |
||||
tmp, err := template.New("").Funcs(map[string]any{ |
||||
"UpperType": func(str string) string { |
||||
return stringx.Ucfirst(ddl.GoName(str)) |
||||
}, |
||||
"LowerType": func(str string) string { |
||||
if str == "type" { |
||||
str = "vtype" |
||||
} |
||||
return stringx.Lcfirst(ddl.GoName(str)) |
||||
}, |
||||
"UpdateColumn": func(columns []*ddl.TableColumn) string { |
||||
arr := []string{} |
||||
for _, v := range columns { |
||||
arr = append(arr, fmt.Sprintf("%s=?", v.Name)) |
||||
} |
||||
return strings.Join(arr, " and ") |
||||
}, |
||||
}).Parse(modelTemplate) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
//ddl.PrintJson(table)
|
||||
//table.Primary.GoType
|
||||
//fmt.Println(table.Primary)
|
||||
f, err := os.Create(saveFileName) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
return tmp.Execute(f, table) |
||||
}) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
exec.Command("gofmt", "-l", "-w", *savePath+"/..").Run() |
||||
|
||||
} |
@ -0,0 +1,77 @@ |
||||
package model |
||||
{{if .Imports}} |
||||
import ( |
||||
{{range $val:= .Imports}} |
||||
"{{$val}}"{{end}} |
||||
) |
||||
{{end}} |
||||
type {{.Name|UpperType}} struct { |
||||
updates map[string]any |
||||
{{range $val := .Columns}}{{$val.Name|UpperType}} {{$val.GoType}} `db:"{{$val.Name}}"`{{if $val.Comment}}//{{$val.Comment}}{{end}} |
||||
{{end}} |
||||
} |
||||
func (s *{{.Name|UpperType}}) TableName() string { |
||||
return "{{.Name}}" |
||||
} |
||||
func New{{.Name|UpperType}}() *{{.Name|UpperType}} { |
||||
return &{{.Name|UpperType}}{ |
||||
updates: make(map[string]any), |
||||
{{range $val := .Columns}}{{if $val.Default}}{{$val.Name|UpperType}}:{{if (eq $val.GoType "string")}}"{{$val.Default}}"{{else}}{{$val.Default}}{{end}}, |
||||
{{end}}{{end}} |
||||
} |
||||
} |
||||
{{range $val := .Columns}} |
||||
func (s *{{$.Name|UpperType}}) Set{{$val.Name|UpperType}}({{$val.Name|LowerType}} {{$val.GoType}}) { |
||||
s.{{$val.Name|UpperType}} = {{$val.Name|LowerType}} |
||||
s.set("{{$val.Name}}", {{$val.Name|LowerType}}) |
||||
} |
||||
{{end}} |
||||
|
||||
func (s *{{.Name|UpperType}}) set(key string, val any) { |
||||
s.updates[key] = val |
||||
} |
||||
|
||||
func (s *{{.Name|UpperType}}) UpdateColumn() map[string]any { |
||||
return s.updates |
||||
} |
||||
|
||||
|
||||
|
||||
type default{{.Name|UpperType}}DAO struct { |
||||
db *gorm.DB |
||||
} |
||||
func New{{.Name|UpperType}}DAO(db *gorm.DB) *default{{.Name|UpperType}}DAO { |
||||
return &default{{.Name|UpperType}}DAO{ |
||||
db: db, |
||||
} |
||||
} |
||||
|
||||
func (s *default{{.Name|UpperType}}DAO) Insert(ctx context.Context, data *{{.Name|UpperType}}) error { |
||||
return s.db.Create(data).Error |
||||
} |
||||
{{ if .Primary}} |
||||
func (s *default{{.Name|UpperType}}DAO) Update(ctx context.Context, {{.Primary.Name|LowerType}} {{.Primary.GoType}}, updates map[string]any) error { |
||||
return s.db.Model({{.Name|UpperType}}{}).Where("{{.Primary.Name}}=?", {{.Primary.Name|LowerType}}).Updates(updates).Error |
||||
} |
||||
|
||||
|
||||
func (s *default{{.Name|UpperType}}DAO) Delete(ctx context.Context, {{.Primary.Name|LowerType}} {{.Primary.GoType}}) error { |
||||
return s.db.Where("{{.Primary.Name}}=?", {{.Primary.Name|LowerType}}).Delete(&{{.Name|UpperType}}{}).Error |
||||
} |
||||
|
||||
|
||||
func (s *default{{.Name|UpperType}}DAO) FindOne(ctx context.Context, {{.Primary.Name|LowerType}} {{.Primary.GoType}}) (*{{.Name|UpperType}}, error) { |
||||
row := {{.Name|UpperType}}{} |
||||
err := s.db.Where("{{.Primary.Name}}=?", {{.Primary.Name|LowerType}}).Find(&row).Error |
||||
if err!=nil{ |
||||
return nil,err |
||||
} |
||||
return &row, nil |
||||
} |
||||
{{end}} |
||||
|
||||
{{range $val:= .Indexes}} |
||||
func (s *default{{$.Name|UpperType}}DAO) Find{{$val.Name|UpperType}}(ctx context.Context {{range $column:=$val.Columns}},{{$column.Name|LowerType}} {{$column.GoType}}{{end}}) error { |
||||
return s.db.Where("{{$val.Columns|UpdateColumn}}" {{range $column:=$val.Columns}},{{$column.Name|LowerType}}{{end}}).Delete(&{{$.Name|UpperType}}{}).Error |
||||
} |
||||
{{end}} |
@ -0,0 +1,42 @@ |
||||
package errors |
||||
|
||||
import "bytes" |
||||
|
||||
type ( |
||||
BatchError struct { |
||||
errs errorArray |
||||
} |
||||
errorArray []error |
||||
) |
||||
|
||||
func (be *BatchError) Add(errs ...error) { |
||||
for _, err := range errs { |
||||
if err != nil { |
||||
be.errs = append(be.errs, err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (be *BatchError) Err() error { |
||||
switch len(be.errs) { |
||||
case 0: |
||||
return nil |
||||
case 1: |
||||
return be.errs[0] |
||||
default: |
||||
return be.errs |
||||
} |
||||
} |
||||
func (be *BatchError) NotNil() bool { |
||||
return len(be.errs) > 0 |
||||
} |
||||
func (ea errorArray) Error() string { |
||||
var buf bytes.Buffer |
||||
for i := range ea { |
||||
if i > 0 { |
||||
buf.WriteByte('\n') |
||||
} |
||||
buf.WriteString(ea[i].Error()) |
||||
} |
||||
return buf.String() |
||||
} |
@ -0,0 +1,7 @@ |
||||
package errors |
||||
|
||||
import ( |
||||
"gorm.io/gorm" |
||||
) |
||||
|
||||
var ErrNotFound = gorm.ErrRecordNotFound |
@ -0,0 +1,44 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"database/sql" |
||||
"fmt" |
||||
_ "github.com/go-sql-driver/mysql" |
||||
"time" |
||||
) |
||||
|
||||
type Account struct { |
||||
Id int64 `db:"id"` |
||||
UserId int64 `db:"user_id"` |
||||
UserType int64 `db:"user_type"` // 1 user 2 employer
|
||||
Name string `db:"name"` |
||||
Email string `db:"email"` |
||||
Password string `db:"password"` |
||||
Phone string `db:"phone"` // +80-18010489927
|
||||
WechatUniqueId string `db:"wechat_unique_id"` |
||||
} |
||||
|
||||
func main() { |
||||
dsn := "root:123456@tcp(127.0.0.1:3306)/test_gozero" |
||||
db, err := sql.Open("mysql", dsn) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
db.SetConnMaxLifetime(time.Minute * 3) |
||||
db.SetMaxOpenConns(10) |
||||
db.SetMaxIdleConns(10) |
||||
if err = db.Ping(); err != nil { |
||||
panic(err) |
||||
} |
||||
|
||||
smt, err := db.Prepare("insert into account (user_id)values (?)") |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
res, err := smt.Exec(1, 3) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
fmt.Println(res.LastInsertId()) |
||||
fmt.Println(res.RowsAffected()) |
||||
} |
@ -0,0 +1,71 @@ |
||||
package lang |
||||
|
||||
import ( |
||||
"fmt" |
||||
"reflect" |
||||
"strconv" |
||||
) |
||||
|
||||
var Placeholder PlaceholderType |
||||
|
||||
type ( |
||||
AnyType = any |
||||
PlaceholderType struct{} |
||||
) |
||||
|
||||
func Repr(v any) string { |
||||
if v == nil { |
||||
return "" |
||||
} |
||||
switch vt := v.(type) { |
||||
case fmt.Stringer: |
||||
return vt.String() |
||||
} |
||||
val := reflect.ValueOf(v) |
||||
|
||||
if val.Kind() == reflect.Ptr && !val.IsNil() { |
||||
val = val.Elem() |
||||
} |
||||
return reprOfValue(val) |
||||
} |
||||
|
||||
func reprOfValue(val reflect.Value) string { |
||||
switch vt := val.Interface().(type) { |
||||
case bool: |
||||
return strconv.FormatBool(vt) |
||||
case error: |
||||
return vt.Error() |
||||
case float32: |
||||
return strconv.FormatFloat(float64(vt), 'f', -1, 32) |
||||
case float64: |
||||
return strconv.FormatFloat(vt, 'f', -1, 64) |
||||
case fmt.Stringer: |
||||
return vt.String() |
||||
case int: |
||||
return strconv.Itoa(vt) |
||||
case int8: |
||||
return strconv.Itoa(int(vt)) |
||||
case int16: |
||||
return strconv.Itoa(int(vt)) |
||||
case int32: |
||||
return strconv.Itoa(int(vt)) |
||||
case int64: |
||||
return strconv.FormatInt(vt, 10) |
||||
case string: |
||||
return vt |
||||
case uint: |
||||
return strconv.FormatUint(uint64(vt), 10) |
||||
case uint8: |
||||
return strconv.FormatUint(uint64(vt), 10) |
||||
case uint16: |
||||
return strconv.FormatUint(uint64(vt), 10) |
||||
case uint32: |
||||
return strconv.FormatUint(uint64(vt), 10) |
||||
case uint64: |
||||
return strconv.FormatUint(vt, 10) |
||||
case []byte: |
||||
return string(vt) |
||||
default: |
||||
return fmt.Sprint(val.Interface()) |
||||
} |
||||
} |
@ -0,0 +1,52 @@ |
||||
package mapping |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/lang" |
||||
"reflect" |
||||
"sync" |
||||
) |
||||
|
||||
const ( |
||||
defaultOption = "default" |
||||
envOption = "env" |
||||
inheritOption = "inherit" |
||||
stringOption = "string" |
||||
optionalOption = "optional" |
||||
optionsOption = "options" |
||||
rangeOption = "range" |
||||
optionSeparator = "|" |
||||
equalToken = "=" |
||||
escapeChar = '\\' |
||||
leftBracket = '(' |
||||
rightBracket = ')' |
||||
leftSquareBracket = '[' |
||||
rightSquareBracket = ']' |
||||
segmentSeparator = ',' |
||||
) |
||||
|
||||
var ( |
||||
errUnsupportedType = errors.New("unsupported type on setting field value") |
||||
errNumberRange = errors.New("wrong number range setting") |
||||
//optionsCache = make(map[string]optionsCacheValue)
|
||||
cacheLock sync.RWMutex |
||||
//structRequiredCache = make(map[reflect.Type]requiredCacheValue)
|
||||
structCacheLock sync.RWMutex |
||||
) |
||||
|
||||
func Deref(t reflect.Type) reflect.Type { |
||||
for t.Kind() == reflect.Ptr { |
||||
t = t.Elem() |
||||
} |
||||
return t |
||||
} |
||||
func ValidatePtr(v *reflect.Value) error { |
||||
if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() { |
||||
return fmt.Errorf("not a valid pointer: %v", v) |
||||
} |
||||
return nil |
||||
} |
||||
func Repr(v any) string { |
||||
return lang.Repr(v) |
||||
} |
@ -0,0 +1,9 @@ |
||||
https://www.antlr.org/ |
||||
|
||||
http://lab.antlr.org/ |
||||
|
||||
https://github.com/antlr/antlr4 |
||||
|
||||
github.com/antlr4-go/antlr |
||||
|
||||
antlr4 -visitor -Dlanguage=Go -o ../gen MySql* |
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,14 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mysql/ddl" |
||||
) |
||||
|
||||
func main() { |
||||
err := ddl.Parser("/Users/mogfee/web/kit/cmd/mysql-kit/test_gozero.sql", func(table *ddl.Table) error { |
||||
ddl.PrintJson(table) |
||||
return nil |
||||
}) |
||||
fmt.Println(err) |
||||
} |
@ -0,0 +1,29 @@ |
||||
package ddl |
||||
|
||||
import ( |
||||
"github.com/antlr4-go/antlr/v4" |
||||
"unicode" |
||||
) |
||||
|
||||
type CaseChangingStream struct { |
||||
antlr.CharStream |
||||
upper bool |
||||
} |
||||
|
||||
func newCaseChangingStream(in antlr.CharStream, upper bool) *CaseChangingStream { |
||||
return &CaseChangingStream{ |
||||
in, |
||||
upper, |
||||
} |
||||
} |
||||
func (is *CaseChangingStream) LA(offset int) int { |
||||
in := is.CharStream.LA(offset) |
||||
if in < 0 { |
||||
return in |
||||
} |
||||
if is.upper { |
||||
return int(unicode.ToUpper(rune(in))) |
||||
} else { |
||||
return int(unicode.ToLower(rune(in))) |
||||
} |
||||
} |
@ -0,0 +1,111 @@ |
||||
package ddl |
||||
|
||||
import ( |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mysql/parser" |
||||
"github.com/antlr4-go/antlr/v4" |
||||
) |
||||
|
||||
func (*visitor) visitorNotNull(ctx *parser.NullColumnConstraintContext) bool { |
||||
if ret, ok := ctx.NullNotnull().(*parser.NullNotnullContext); ok { |
||||
if ret.NOT() != nil { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
func (v *visitor) getTableColumn(tx *parser.ColumnDeclarationContext) *TableColumn { |
||||
column := &TableColumn{} |
||||
iDefinitionContext := tx.ColumnDefinition() |
||||
column.Name = trimText(tx.FullColumnName().Uid().GetText()) |
||||
fmt.Println(column.Name) |
||||
defineContext, ok := iDefinitionContext.(*parser.ColumnDefinitionContext) |
||||
for _, v := range defineContext.DataType().GetChildren() { |
||||
if aa, ok := v.GetPayload().(*antlr.CommonToken); ok { |
||||
if aa.GetText() == "COLLATE" { |
||||
continue |
||||
} |
||||
if aa.GetText() == "unsigned" { |
||||
continue |
||||
} |
||||
if aa.GetText() == "zerofill" { |
||||
continue |
||||
} |
||||
column.Type = aa.GetText() |
||||
} |
||||
} |
||||
if ok { |
||||
for _, e := range defineContext.AllColumnConstraint() { |
||||
switch t := e.(type) { |
||||
case *parser.NullColumnConstraintContext: |
||||
column.IsNotNull = v.visitorNotNull(t) |
||||
case *parser.DefaultColumnConstraintContext: |
||||
defaultVal := trimText(t.DefaultValue().GetText()) |
||||
if defaultVal == "NULL" { |
||||
column.IsDefault = false |
||||
} else { |
||||
column.IsDefault = true |
||||
column.Default = defaultVal |
||||
} |
||||
case *parser.AutoIncrementColumnConstraintContext: |
||||
//if t.AUTO_INCREMENT() != nil {
|
||||
// column.IsAutoIncrement = true
|
||||
//}
|
||||
case *parser.CommentColumnConstraintContext: |
||||
column.Comment = trimText(t.STRING_LITERAL().GetText()) |
||||
default: |
||||
fmt.Printf("###=======%T\n", e) |
||||
} |
||||
} |
||||
} |
||||
return column |
||||
} |
||||
func (v *visitor) getPrimaryOrUniqueKey(tx *parser.ConstraintDeclarationContext) *TableIndex { |
||||
index := &TableIndex{} |
||||
if primary, ok := tx.TableConstraint().(*parser.PrimaryKeyTableConstraintContext); ok { |
||||
index.IsPrimary = true |
||||
if primary.PRIMARY() != nil { |
||||
aa := primary.IndexColumnNames().(*parser.IndexColumnNamesContext) |
||||
for _, a := range aa.AllIndexColumnName() { |
||||
index.Name = trimText(a.Uid().GetText()) |
||||
index.ColumnsStr = []string{trimText(a.Uid().GetText())} |
||||
} |
||||
} |
||||
} else if unique, ok := tx.TableConstraint().(*parser.UniqueKeyTableConstraintContext); ok { |
||||
index.IsUnique = true |
||||
if len(unique.AllUid()) > 0 { |
||||
index.Name = trimText(unique.AllUid()[0].GetText()) |
||||
} |
||||
for _, aa2 := range unique.IndexColumnNames().AllIndexColumnName() { |
||||
index.ColumnsStr = append(index.ColumnsStr, trimText(aa2.GetText())) |
||||
} |
||||
} else { |
||||
fmt.Printf("%T\n", tx) |
||||
} |
||||
return index |
||||
} |
||||
func (v *visitor) getIndex(tx *parser.IndexDeclarationContext) *TableIndex { |
||||
index := &TableIndex{} |
||||
if a, ok := tx.IndexColumnDefinition().(*parser.SimpleIndexDeclarationContext); ok { |
||||
index.Name = trimText(a.Uid().GetText()) |
||||
index.ColumnsStr = []string{} |
||||
for _, aa := range a.IndexColumnNames().AllIndexColumnName() { |
||||
index.ColumnsStr = append(index.ColumnsStr, trimText(aa.Uid().GetText())) |
||||
} |
||||
} else { |
||||
fmt.Printf("%T\n", tx.IndexColumnDefinition()) |
||||
} |
||||
return index |
||||
} |
||||
func (v *visitor) getIndexColumn(table *Table, index *TableIndex) *TableIndex { |
||||
columns := make([]*TableColumn, 0, len(index.ColumnsStr)) |
||||
for _, v := range index.ColumnsStr { |
||||
for _, tc := range table.Columns { |
||||
if tc.Name == v { |
||||
columns = append(columns, tc) |
||||
} |
||||
} |
||||
} |
||||
index.Columns = columns |
||||
return index |
||||
} |
@ -0,0 +1,77 @@ |
||||
package ddl |
||||
|
||||
import ( |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mysql/parser" |
||||
"strings" |
||||
) |
||||
|
||||
func (v *visitor) visitCreateTable(ctx parser.ICreateTableContext) any { |
||||
//v.trace("visitCreateTable")
|
||||
switch tx := ctx.(type) { |
||||
case *parser.CopyCreateTableContext: |
||||
v.panicWithExpr(tx.GetStart(), |
||||
"Unsupported creating a table by copying from another table", |
||||
) |
||||
case *parser.QueryCreateTableContext: |
||||
v.panicWithExpr(tx.GetStart(), |
||||
"Unsupported creating a table by querying from another table", |
||||
) |
||||
case *parser.ColumnCreateTableContext: |
||||
return v.visitColumnCreateTable(tx) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (v *visitor) visitColumnCreateTable(ctx *parser.ColumnCreateTableContext) *Table { |
||||
//v.trace("visitColumnCreateTable")
|
||||
|
||||
table := &Table{ |
||||
Name: trimText(ctx.TableName().GetText()), |
||||
} |
||||
fmt.Println("========================", table.Name) |
||||
if ctx.CreateDefinitions() != nil { |
||||
if defctx, ok := ctx.CreateDefinitions().(*parser.CreateDefinitionsContext); ok { |
||||
v.visitCreateDefinitions(defctx, table) |
||||
} |
||||
} |
||||
return table |
||||
} |
||||
func (v *visitor) visitCreateDefinitions(ctx *parser.CreateDefinitionsContext, table *Table) { |
||||
//v.trace("visitCreateDefinitions")
|
||||
for _, e := range ctx.AllCreateDefinition() { |
||||
v.VisitCreateDefinition(e, table) |
||||
} |
||||
} |
||||
func (v *visitor) VisitCreateDefinition(ctx parser.ICreateDefinitionContext, table *Table) { |
||||
//v.trace("VisitCreateDefinition")
|
||||
switch tx := ctx.(type) { |
||||
case *parser.ColumnDeclarationContext: |
||||
table.Columns = append(table.Columns, v.getTableColumn(tx)) |
||||
case *parser.ConstraintDeclarationContext: |
||||
data := v.getPrimaryOrUniqueKey(tx) |
||||
if data.IsPrimary { |
||||
for _, col := range table.Columns { |
||||
if col.Name == data.Name { |
||||
table.Primary = col |
||||
col.IsPrimary = true |
||||
} |
||||
} |
||||
} |
||||
if !data.IsPrimary { |
||||
table.Indexes = append(table.Indexes, v.getIndexColumn(table, data)) |
||||
} |
||||
case *parser.IndexDeclarationContext: |
||||
//fmt.Println("index", tx.GetText())
|
||||
table.Indexes = append(table.Indexes, v.getIndexColumn(table, v.getIndex(tx))) |
||||
default: |
||||
fmt.Printf("%T\n", tx) |
||||
} |
||||
} |
||||
|
||||
func trimText(str string) string { |
||||
str = strings.Trim(str, "`") |
||||
str = strings.Trim(str, "'") |
||||
replacer := strings.NewReplacer("\r", "", "\n", "") |
||||
return replacer.Replace(str) |
||||
} |
@ -0,0 +1,107 @@ |
||||
package ddl |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mysql/parser" |
||||
"github.com/antlr4-go/antlr/v4" |
||||
"os" |
||||
"strings" |
||||
) |
||||
|
||||
func Parser(fileName string, fun func(table *Table) error) error { |
||||
body, err := os.ReadFile(fileName) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
inputStream := antlr.NewInputStream(string(body)) |
||||
caseStream := newCaseChangingStream(inputStream, true) |
||||
lexer := parser.NewMySqlLexer(caseStream) |
||||
lexer.RemoveErrorListeners() |
||||
tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel) |
||||
mysqlParser := parser.NewMySqlParser(tokens) |
||||
mysqlParser.RemoveErrorListeners() |
||||
vis := visitor{} |
||||
data := mysqlParser.Root().Accept(&vis) |
||||
if createTables, ok := data.([]*Table); ok { |
||||
for _, table := range createTables { |
||||
if err = fun(table); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
type visitor struct { |
||||
parser.BaseMySqlParserVisitor |
||||
} |
||||
|
||||
func (v *visitor) panicWithExpr(expr any, msg string) { |
||||
panic(msg) |
||||
} |
||||
func (v *visitor) trace(msg ...interface{}) { |
||||
fmt.Println("Visit Trace: " + fmt.Sprint(msg...)) |
||||
} |
||||
func (v *visitor) VisitRoot(ctx *parser.RootContext) interface{} { |
||||
//v.trace("VisitRoot")
|
||||
if ctx.SqlStatements() != nil { |
||||
return ctx.SqlStatements().Accept(v) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (v *visitor) VisitSqlStatements(ctx *parser.SqlStatementsContext) interface{} { |
||||
//v.trace("VisitSqlStatements")
|
||||
createTables := make([]*Table, 0) |
||||
for _, e := range ctx.AllSqlStatement() { |
||||
ret := e.Accept(v) |
||||
if ret == nil { |
||||
continue |
||||
} |
||||
if table, ok := ret.(*Table); ok { |
||||
table.Imports = table.GetImports() |
||||
for _, v := range table.Columns { |
||||
v.GoType = v.GetTypeStr() |
||||
v.Default = v.GetDefault() |
||||
} |
||||
|
||||
createTables = append(createTables, table) |
||||
} |
||||
} |
||||
return createTables |
||||
} |
||||
|
||||
func (v *visitor) VisitSqlStatement(ctx *parser.SqlStatementContext) interface{} { |
||||
//v.trace("VisitSqlStatement")
|
||||
if ctx.DdlStatement() != nil { |
||||
return ctx.DdlStatement().Accept(v) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (v *visitor) VisitDdlStatement(ctx *parser.DdlStatementContext) interface{} { |
||||
//v.trace("VisitDdlStatement")
|
||||
if ctx.CreateTable() != nil { |
||||
return v.visitCreateTable(ctx.CreateTable()) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func PrintJson(dat any) { |
||||
b, _ := json.MarshalIndent(dat, " ", " ") |
||||
fmt.Println(string(b)) |
||||
} |
||||
|
||||
func GoName(str string) string { |
||||
//数字开头
|
||||
if str[0] >= 48 && str[0] <= 57 { |
||||
str = "a" + str |
||||
} |
||||
arr := strings.Split(str, "_") |
||||
newArr := []string{} |
||||
for _, v := range arr { |
||||
newArr = append(newArr, strings.Title(v)) |
||||
} |
||||
return strings.Join(newArr, "") |
||||
} |
@ -0,0 +1,85 @@ |
||||
package ddl |
||||
|
||||
type TableColumn struct { |
||||
Name string |
||||
Type string |
||||
GoType string |
||||
Default string |
||||
Comment string |
||||
IsPrimary bool |
||||
IsNotNull bool |
||||
IsDefault bool |
||||
} |
||||
|
||||
func (t *TableColumn) GetDefault() string { |
||||
if t.Default == "0" { |
||||
return "" |
||||
} |
||||
if t.Default == "0000-00-00 00:00:00" { |
||||
return "" |
||||
} |
||||
return t.Default |
||||
} |
||||
func (t *TableColumn) GetTypeStr() string { |
||||
switch t.Type { |
||||
case "bigint": |
||||
return "int64" |
||||
case "datetime", "timetime", "date", "timestamp": |
||||
return "time.Time" |
||||
case "tinyint", "smallint", "int": |
||||
return "int32" |
||||
case "varchar", "char", "text", "longtext": |
||||
return "string" |
||||
case "float": |
||||
return "float32" |
||||
case "decimal": |
||||
return "float64" |
||||
} |
||||
return t.Type |
||||
} |
||||
|
||||
func (t *TableColumn) GetTypeImport() string { |
||||
switch t.Type { |
||||
case "datetime", "timetime", "date", "timestamp": |
||||
return "time" |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
type TableIndex struct { |
||||
Name string |
||||
IsPrimary bool |
||||
IsUnique bool |
||||
ColumnsStr []string |
||||
Columns []*TableColumn |
||||
} |
||||
|
||||
func (t *TableIndex) GoColumns() []string { |
||||
newArr := []string{} |
||||
for _, v := range t.ColumnsStr { |
||||
newArr = append(newArr, GoName(v)) |
||||
} |
||||
return newArr |
||||
} |
||||
|
||||
type Table struct { |
||||
Name string |
||||
Imports []string |
||||
Columns []*TableColumn |
||||
Indexes []*TableIndex |
||||
Primary *TableColumn |
||||
} |
||||
|
||||
func (t *Table) GetImports() []string { |
||||
arr := map[string]string{} |
||||
for _, v := range t.Columns { |
||||
if imp := v.GetTypeImport(); imp != "" { |
||||
arr[imp] = imp |
||||
} |
||||
} |
||||
newArr := []string{} |
||||
for _, v := range arr { |
||||
newArr = append(newArr, v) |
||||
} |
||||
return newArr |
||||
} |
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,10 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"errors" |
||||
) |
||||
|
||||
var ( |
||||
errCantNestTx = errors.New("cannot nest transactions") |
||||
errNoRawDBFromTx = errors.New("cannot get raw db from transaction") |
||||
) |
@ -0,0 +1,30 @@ |
||||
package sqlx |
||||
|
||||
import "github.com/go-sql-driver/mysql" |
||||
|
||||
const ( |
||||
mysqlDriverName = "mysql" |
||||
duplicateEntryCode uint16 = 1062 |
||||
) |
||||
|
||||
func NewMysql(datasource string, opts ...SqlOption) SqlConn { |
||||
opts = append(opts, func(conn *commonSqlConn) { |
||||
conn.accept = func(err error) bool { |
||||
if err == nil { |
||||
return true |
||||
} |
||||
myerr, ok := err.(*mysql.MySQLError) |
||||
if !ok { |
||||
return false |
||||
} |
||||
switch myerr.Number { |
||||
case duplicateEntryCode: |
||||
return true |
||||
default: |
||||
return false |
||||
} |
||||
} |
||||
}) |
||||
|
||||
return NewSqlConn(mysqlDriverName, datasource, opts...) |
||||
} |
@ -0,0 +1,229 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"errors" |
||||
"git.diulo.com/mogfee/kit/mapping" |
||||
"reflect" |
||||
"strings" |
||||
) |
||||
|
||||
const tagName = "db" |
||||
|
||||
var ( |
||||
ErrNotMatchDestination = errors.New("not matching destination to scan") |
||||
ErrNotReadableValue = errors.New("value not addressable or interfaceable") |
||||
ErrNotSettable = errors.New("passed in variable is not settable") |
||||
ErrUnsupportedValueType = errors.New("unsupported unmarshal type") |
||||
) |
||||
|
||||
type rowsScanner interface { |
||||
Columns() ([]string, error) |
||||
Err() error |
||||
Next() bool |
||||
Scan(v ...any) error |
||||
} |
||||
|
||||
func unmarshalRow(v any, scanner rowsScanner, strict bool) error { |
||||
if !scanner.Next() { |
||||
if err := scanner.Err(); err != nil { |
||||
return err |
||||
} |
||||
return ErrNotFound |
||||
} |
||||
rv := reflect.ValueOf(v) |
||||
if err := mapping.ValidatePtr(&rv); err != nil { |
||||
return err |
||||
} |
||||
rte := reflect.TypeOf(v).Elem() |
||||
rve := rv.Elem() |
||||
switch rte.Kind() { |
||||
case reflect.Bool, |
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, |
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, |
||||
reflect.Float32, reflect.Float64, |
||||
reflect.String: |
||||
if rve.CanSet() { |
||||
return scanner.Scan(v) |
||||
} |
||||
return ErrNotSettable |
||||
case reflect.Struct: |
||||
columns, err := scanner.Columns() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
values, err := mapStructFieldsIntoSlice(rve, columns, strict) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return scanner.Scan(values...) |
||||
default: |
||||
return ErrUnsupportedValueType |
||||
} |
||||
} |
||||
func unmarshalRows(v any, scanner rowsScanner, strict bool) error { |
||||
rv := reflect.ValueOf(v) |
||||
if err := mapping.ValidatePtr(&rv); err != nil { |
||||
return err |
||||
} |
||||
rt := reflect.TypeOf(v) |
||||
rte := rt.Elem() |
||||
rve := rv.Elem() |
||||
switch rte.Kind() { |
||||
case reflect.Slice: |
||||
if rve.CanSet() { |
||||
ptr := rte.Elem().Kind() == reflect.Ptr |
||||
appendFn := func(item reflect.Value) { |
||||
if ptr { |
||||
rve.Set(reflect.Append(rve, item)) |
||||
} else { |
||||
rve.Set(reflect.Append(rve, reflect.Indirect(item))) |
||||
} |
||||
} |
||||
fillFn := func(value any) error { |
||||
if rve.CanSet() { |
||||
if err := scanner.Scan(value); err != nil { |
||||
return err |
||||
} |
||||
appendFn(reflect.ValueOf(value)) |
||||
return nil |
||||
} |
||||
return ErrNotSettable |
||||
} |
||||
base := mapping.Deref(rte.Elem()) |
||||
switch base.Kind() { |
||||
case reflect.Bool, |
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, |
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, |
||||
reflect.Float32, reflect.Float64, |
||||
reflect.String: |
||||
for scanner.Next() { |
||||
value := reflect.New(base) |
||||
if err := fillFn(value.Interface()); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
case reflect.Struct: |
||||
columns, err := scanner.Columns() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
for scanner.Next() { |
||||
value := reflect.New(base) |
||||
values, err := mapStructFieldsIntoSlice(value, columns, strict) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err = scanner.Scan(values...); err != nil { |
||||
return err |
||||
} |
||||
appendFn(value) |
||||
} |
||||
default: |
||||
return ErrUnsupportedValueType |
||||
} |
||||
return nil |
||||
} |
||||
return ErrNotSettable |
||||
default: |
||||
return ErrUnsupportedValueType |
||||
} |
||||
} |
||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) { |
||||
fields := unwrapFields(v) |
||||
if strict && len(columns) < len(fields) { |
||||
return nil, ErrNotMatchDestination |
||||
} |
||||
taggedMap, err := getTaggedFieldValueMap(v) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
values := make([]any, len(columns)) |
||||
if len(taggedMap) == 0 { |
||||
for i := 0; i < len(values); i++ { |
||||
valueField := fields[i] |
||||
switch valueField.Kind() { |
||||
case reflect.Ptr: |
||||
if !valueField.CanInterface() { |
||||
return nil, ErrNotReadableValue |
||||
} |
||||
if valueField.IsNil() { |
||||
baseValueType := mapping.Deref(valueField.Type()) |
||||
valueField.Set(reflect.New(baseValueType)) |
||||
} |
||||
values[i] = valueField.Interface() |
||||
default: |
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() { |
||||
return nil, ErrNotReadableValue |
||||
} |
||||
values[i] = valueField.Addr().Interface() |
||||
} |
||||
} |
||||
} else { |
||||
for i, column := range columns { |
||||
if tagged, ok := taggedMap[column]; ok { |
||||
values[i] = tagged |
||||
} else { |
||||
var anoymouse any |
||||
values[i] = &anoymouse |
||||
} |
||||
} |
||||
} |
||||
return values, nil |
||||
} |
||||
func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) { |
||||
rt := mapping.Deref(v.Type()) |
||||
size := rt.NumField() |
||||
result := make(map[string]any, size) |
||||
for i := 0; i < size; i++ { |
||||
key := parseTagName(rt.Field(i)) |
||||
if len(key) == 0 { |
||||
return nil, nil |
||||
} |
||||
valueField := reflect.Indirect(v).Field(i) |
||||
switch valueField.Kind() { |
||||
case reflect.Ptr: |
||||
if !valueField.CanInterface() { |
||||
return nil, ErrNotReadableValue |
||||
} |
||||
if valueField.IsNil() { |
||||
baseValueType := mapping.Deref(valueField.Type()) |
||||
valueField.Set(reflect.New(baseValueType)) |
||||
} |
||||
result[key] = valueField.Interface() |
||||
default: |
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() { |
||||
return nil, ErrNotReadableValue |
||||
} |
||||
result[key] = valueField.Addr().Interface() |
||||
} |
||||
} |
||||
return result, nil |
||||
} |
||||
func parseTagName(field reflect.StructField) string { |
||||
key := field.Tag.Get(tagName) |
||||
if len(key) == 0 { |
||||
return "" |
||||
} |
||||
options := strings.Split(key, ",") |
||||
return options[0] |
||||
} |
||||
func unwrapFields(v reflect.Value) []reflect.Value { |
||||
var fields []reflect.Value |
||||
indirect := reflect.Indirect(v) |
||||
for i := 0; i < indirect.NumField(); i++ { |
||||
child := indirect.Field(i) |
||||
if child.Kind() == reflect.Ptr && child.IsNil() { |
||||
baseValueType := mapping.Deref(child.Type()) |
||||
child.Set(reflect.New(baseValueType)) |
||||
} |
||||
child = reflect.Indirect(child) |
||||
childType := indirect.Type().Field(i) |
||||
if child.Kind() == reflect.Struct && childType.Anonymous { |
||||
fields = append(fields, unwrapFields(child)...) |
||||
} else { |
||||
fields = append(fields, child) |
||||
} |
||||
} |
||||
return fields |
||||
} |
@ -0,0 +1,315 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
) |
||||
|
||||
const spanName = "sql" |
||||
|
||||
var ErrNotFound = sql.ErrNoRows |
||||
|
||||
type ( |
||||
Session interface { |
||||
Exec(query string, args ...any) (sql.Result, error) |
||||
ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) |
||||
Prepare(query string) (StmtSession, error) |
||||
PrepareCtx(ctx context.Context, query string) (StmtSession, error) |
||||
QueryRow(v any, query string, args ...any) error |
||||
QueryRowCtx(ctx context.Context, v any, query string, args ...any) error |
||||
QueryRowPartial(v any, query string, args ...any) error |
||||
QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error |
||||
QueryRows(v any, query string, args ...any) error |
||||
QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error |
||||
QueryRowsPartial(v any, query string, args ...any) error |
||||
QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error |
||||
} |
||||
SqlConn interface { |
||||
Session |
||||
RawDB() (*sql.DB, error) |
||||
Transact(fn func(Session) error) error |
||||
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error |
||||
} |
||||
SqlOption func(*commonSqlConn) |
||||
StmtSession interface { |
||||
Close() error |
||||
Exec(args ...any) (sql.Result, error) |
||||
ExecCtx(ctx context.Context, args ...any) (sql.Result, error) |
||||
QueryRow(v any, args ...any) error |
||||
QueryRowCtx(ctx context.Context, v any, args ...any) error |
||||
QueryRowPartial(v any, args ...any) error |
||||
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error |
||||
QueryRows(v any, args ...any) error |
||||
QueryRowsCtx(ctx context.Context, v any, args ...any) error |
||||
QueryRowsPartial(v any, args ...any) error |
||||
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error |
||||
} |
||||
|
||||
commonSqlConn struct { |
||||
connProv connProvider |
||||
onError func(ctx context.Context, err error) |
||||
accept func(error) bool |
||||
beginTx beginnable |
||||
} |
||||
connProvider func() (db *sql.DB, err error) |
||||
|
||||
sessionConn interface { |
||||
Exec(query string, args ...any) (sql.Result, error) |
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) |
||||
Query(query string, args ...any) (*sql.Rows, error) |
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) |
||||
} |
||||
statement struct { |
||||
query string |
||||
stmt *sql.Stmt |
||||
} |
||||
stmtConn interface { |
||||
Exec(args ...any) (sql.Result, error) |
||||
ExecContext(ctx context.Context, args ...any) (sql.Result, error) |
||||
Query(args ...any) (*sql.Rows, error) |
||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) |
||||
} |
||||
) |
||||
|
||||
func NewSqlConn(driverName string, datasource string, opts ...SqlOption) SqlConn { |
||||
conn := &commonSqlConn{ |
||||
connProv: func() (db *sql.DB, err error) { |
||||
return getSqlConn(driverName, datasource) |
||||
}, |
||||
onError: func(ctx context.Context, err error) { |
||||
logInstanceError(ctx, datasource, err) |
||||
}, |
||||
beginTx: begin, |
||||
} |
||||
for _, opt := range opts { |
||||
opt(conn) |
||||
} |
||||
return conn |
||||
} |
||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { |
||||
conn := &commonSqlConn{ |
||||
connProv: func() (db *sql.DB, err error) { |
||||
return db, nil |
||||
}, |
||||
onError: func(ctx context.Context, err error) { |
||||
//logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
|
||||
}, |
||||
beginTx: begin, |
||||
} |
||||
for _, opt := range opts { |
||||
opt(conn) |
||||
} |
||||
return conn |
||||
} |
||||
func (db *commonSqlConn) Exec(query string, args ...any) (sql.Result, error) { |
||||
return db.ExecCtx(context.Background(), query, args...) |
||||
} |
||||
|
||||
func (db *commonSqlConn) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) { |
||||
ctx, span := startSpan(ctx, "Exec") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
conn, err := db.connProv() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
result, err = exec(ctx, conn, query, args...) |
||||
return result, err |
||||
} |
||||
|
||||
func (db *commonSqlConn) Prepare(query string) (StmtSession, error) { |
||||
return db.PrepareCtx(context.Background(), query) |
||||
} |
||||
|
||||
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) { |
||||
ctx, span := startSpan(ctx, "Prepare") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
var conn *sql.DB |
||||
conn, err = db.connProv() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
var st *sql.Stmt |
||||
st, err = conn.PrepareContext(ctx, query) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
stmt = statement{ |
||||
stmt: st, |
||||
query: query, |
||||
} |
||||
|
||||
return |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRow(v any, query string, args ...any) error { |
||||
return db.QueryRowCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRow") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
|
||||
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
||||
return unmarshalRow(v, rows, true) |
||||
}, query, args...) |
||||
return |
||||
} |
||||
func (db *commonSqlConn) QueryRowPartial(v any, query string, args ...any) error { |
||||
return db.QueryRowPartialCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRowPartial") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
||||
return unmarshalRow(v, rows, false) |
||||
}, query, args...) |
||||
return |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRows(v any, query string, args ...any) error { |
||||
return db.QueryRowsCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRows") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
||||
return unmarshalRows(v, rows, true) |
||||
}, query, args...) |
||||
return err |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRowsPartial(v any, query string, args ...any) error { |
||||
return db.QueryRowPartialCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRowsPartial") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
||||
return unmarshalRows(v, rows, false) |
||||
}, query, args...) |
||||
return |
||||
} |
||||
|
||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) { |
||||
return db.connProv() |
||||
} |
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error { |
||||
return db.TransactCtx(context.Background(), func(ctx context.Context, session Session) error { |
||||
return fn(session) |
||||
}) |
||||
} |
||||
|
||||
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRowsPartial") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
|
||||
err = transact(ctx, db, db.beginTx, fn) |
||||
return |
||||
} |
||||
func (db *commonSqlConn) acceptable(err error) bool { |
||||
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled |
||||
if db.accept == nil { |
||||
return ok |
||||
} |
||||
return ok || db.accept(err) |
||||
} |
||||
|
||||
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(rows *sql.Rows) error, q string, args ...any) (err error) { |
||||
conn, err := db.connProv() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return query(ctx, conn, func(rows *sql.Rows) error { |
||||
return scanner(rows) |
||||
}, q, args...) |
||||
} |
||||
|
||||
func (s statement) Close() error { |
||||
return s.stmt.Close() |
||||
} |
||||
|
||||
func (s statement) Exec(args ...any) (sql.Result, error) { |
||||
return s.ExecCtx(context.Background(), args...) |
||||
} |
||||
|
||||
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) { |
||||
ctx, span := startSpan(ctx, "Exec") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
return execStmt(ctx, s.stmt, s.query, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRow(v any, args ...any) error { |
||||
return s.QueryRowCtx(context.Background(), v, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRow") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
||||
return unmarshalRow(v, rows, true) |
||||
}, s.query, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRowPartial(v any, args ...any) error { |
||||
return s.QueryRowPartialCtx(context.Background(), v, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRowPartial") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
||||
return unmarshalRow(v, rows, false) |
||||
}, s.query, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRows(v any, args ...any) error { |
||||
return s.QueryRowsCtx(context.Background(), v, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRows") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
||||
return unmarshalRows(v, rows, true) |
||||
}, s.query, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRowsPartial(v any, args ...any) error { |
||||
return s.QueryRowsPartialCtx(context.Background(), v, args...) |
||||
} |
||||
|
||||
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) { |
||||
ctx, span := startSpan(ctx, "QueryRowsPartial") |
||||
defer func() { |
||||
endSpan(span, err) |
||||
}() |
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
||||
return unmarshalRows(v, rows, false) |
||||
}, s.query, args...) |
||||
} |
@ -0,0 +1,107 @@ |
||||
package sqlx_test |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mysql/sqlx" |
||||
"go.opentelemetry.io/otel" |
||||
"go.opentelemetry.io/otel/attribute" |
||||
"go.opentelemetry.io/otel/exporters/jaeger" |
||||
"go.opentelemetry.io/otel/sdk/resource" |
||||
"go.opentelemetry.io/otel/sdk/trace" |
||||
semconv "go.opentelemetry.io/otel/semconv/v1.17.0" |
||||
"io" |
||||
"testing" |
||||
"time" |
||||
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace" |
||||
) |
||||
|
||||
var datasource = "root:123456@tcp(127.0.0.1:3306)/test_gozero" |
||||
var driverName = "mysql" |
||||
|
||||
func TestNewSqlConn(t *testing.T) { |
||||
sqlConn := sqlx.NewSqlConn(driverName, datasource) |
||||
//result, err := sqlConn.Exec("insert into user(name) value ('aa')")
|
||||
//if err != nil {
|
||||
// t.Error(err)
|
||||
//}
|
||||
//fmt.Println(result.LastInsertId())
|
||||
//var id int
|
||||
//if err := sqlConn.QueryRow(&id, "select name from user where id=?", 1); err != nil {
|
||||
// t.Error(err)
|
||||
//}
|
||||
//fmt.Println(id)
|
||||
//return
|
||||
type row struct { |
||||
Id int64 `db:"id"` |
||||
Name string `db:"name"` |
||||
Email string `db:"email"` |
||||
} |
||||
|
||||
//data := row{}
|
||||
//if err := sqlConn.QueryRow(&data, "select id,name,email from user"); err != nil {
|
||||
// t.Error(err)
|
||||
//}
|
||||
//list := make([]*row, 0)
|
||||
//if err := sqlConn.QueryRows(&list, "select id,name,email from user limit 2"); err != nil {
|
||||
// t.Error(err)
|
||||
//}
|
||||
//fmt.Printf("%+v\n", list)
|
||||
|
||||
//exp, err := newExporter(os.Stdout)
|
||||
//if err != nil {
|
||||
// panic(err)
|
||||
//}
|
||||
r, _ := resource.Merge(resource.Default(), resource.NewWithAttributes( |
||||
semconv.SchemaURL, |
||||
semconv.ServiceName("fib"), |
||||
semconv.ServiceVersion("v0.1.0"), |
||||
attribute.String("environment", "demo"), |
||||
)) |
||||
|
||||
exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint("http://127.0.0.1:14268/api/traces"))) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
tp := trace.NewTracerProvider(trace.WithBatcher(exp), trace.WithResource(r)) |
||||
defer func() { |
||||
if err = tp.Shutdown(context.Background()); err != nil { |
||||
panic(err) |
||||
} |
||||
}() |
||||
otel.SetTracerProvider(tp) |
||||
ctx := context.Background() |
||||
newCtx, span := otel.Tracer("sql_client").Start(ctx, "Run") |
||||
err = sqlConn.Transact(func(session sqlx.Session) error { |
||||
data := row{} |
||||
err := session.QueryRowCtx(newCtx, &data, "select * from user where id=?", 1) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
fmt.Println(data.Name) |
||||
res, err := session.ExecCtx(newCtx, "update user set name=? where id=?", time.Now().Unix(), 1) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
fmt.Println(res.RowsAffected()) |
||||
fmt.Println(res.LastInsertId()) |
||||
return nil |
||||
}) |
||||
if err != nil { |
||||
t.Error(err) |
||||
} |
||||
span.End() |
||||
time.Sleep(time.Second) |
||||
} |
||||
|
||||
// newExporter returns a console exporter.
|
||||
func newExporter(w io.Writer) (trace.SpanExporter, error) { |
||||
return stdouttrace.New( |
||||
stdouttrace.WithWriter(w), |
||||
// Use human-readable output.
|
||||
stdouttrace.WithPrettyPrint(), |
||||
// Do not print timestamps for the demo.
|
||||
stdouttrace.WithoutTimestamps(), |
||||
) |
||||
} |
@ -0,0 +1,43 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"database/sql" |
||||
"git.diulo.com/mogfee/kit/syncx" |
||||
"io" |
||||
"time" |
||||
) |
||||
|
||||
const ( |
||||
maxIdleConns = 64 |
||||
maxOpenConns = 64 |
||||
maxLifetime = time.Minute |
||||
) |
||||
|
||||
var connManager = syncx.NewResourceManager() |
||||
|
||||
func getCachedSqlConn(driverName string, server string) (*sql.DB, error) { |
||||
val, err := connManager.GetResource(server, func() (io.Closer, error) { |
||||
return newDBConnection(driverName, server) |
||||
}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return val.(*sql.DB), nil |
||||
} |
||||
|
||||
func getSqlConn(driverName, server string) (*sql.DB, error) { |
||||
return getCachedSqlConn(driverName, server) |
||||
} |
||||
func newDBConnection(driverName, datasource string) (*sql.DB, error) { |
||||
conn, err := sql.Open(driverName, datasource) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
conn.SetMaxIdleConns(maxIdleConns) |
||||
conn.SetMaxOpenConns(maxOpenConns) |
||||
conn.SetConnMaxLifetime(maxLifetime) |
||||
if err = conn.Ping(); err != nil { |
||||
return nil, err |
||||
} |
||||
return conn, nil |
||||
} |
@ -0,0 +1,31 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
) |
||||
|
||||
func exec(ctx context.Context, conn sessionConn, q string, args ...any) (sql.Result, error) { |
||||
result, err := conn.ExecContext(ctx, q, args...) |
||||
return result, err |
||||
} |
||||
func query(ctx context.Context, conn sessionConn, scanner func(rows *sql.Rows) error, q string, args ...any) error { |
||||
rows, err := conn.QueryContext(ctx, q, args...) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer rows.Close() |
||||
return scanner(rows) |
||||
} |
||||
func queryStmt(ctx context.Context, conn stmtConn, scanner func(rows *sql.Rows) error, query string, args ...any) error { |
||||
rows, err := conn.QueryContext(ctx, args...) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer rows.Close() |
||||
return scanner(rows) |
||||
} |
||||
func execStmt(ctx context.Context, conn stmtConn, q string, args ...any) (sql.Result, error) { |
||||
result, err := conn.ExecContext(ctx, args...) |
||||
return result, err |
||||
} |
@ -0,0 +1,28 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"git.diulo.com/mogfee/kit/trace" |
||||
"go.opentelemetry.io/otel/attribute" |
||||
"go.opentelemetry.io/otel/codes" |
||||
oteltrace "go.opentelemetry.io/otel/trace" |
||||
) |
||||
|
||||
var sqlAttributeKey = attribute.Key("sql.method") |
||||
|
||||
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) { |
||||
tracer := trace.TracerFromContext(ctx) |
||||
start, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient)) |
||||
span.SetAttributes(sqlAttributeKey.String(method)) |
||||
return start, span |
||||
} |
||||
func endSpan(span oteltrace.Span, err error) { |
||||
defer span.End() |
||||
if err == nil || err == sql.ErrNoRows { |
||||
span.SetStatus(codes.Ok, "") |
||||
return |
||||
} |
||||
span.SetStatus(codes.Error, err.Error()) |
||||
span.RecordError(err) |
||||
} |
@ -0,0 +1,128 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"context" |
||||
"database/sql" |
||||
"fmt" |
||||
) |
||||
|
||||
type ( |
||||
beginnable func(*sql.DB) (trans, error) |
||||
trans interface { |
||||
Session |
||||
Commit() error |
||||
Rollback() error |
||||
} |
||||
txSession struct { |
||||
*sql.Tx |
||||
} |
||||
) |
||||
|
||||
func NewSession(tx *sql.Tx) Session { |
||||
return txSession{Tx: tx} |
||||
} |
||||
|
||||
func (t txSession) Exec(query string, args ...any) (sql.Result, error) { |
||||
return t.ExecContext(context.Background(), query, args...) |
||||
} |
||||
|
||||
func (t txSession) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) { |
||||
result, err = exec(ctx, t.Tx, query, args...) |
||||
return |
||||
} |
||||
|
||||
func (t txSession) Prepare(query string) (StmtSession, error) { |
||||
return t.PrepareCtx(context.Background(), query) |
||||
} |
||||
|
||||
func (t txSession) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { |
||||
stmt, err := t.Tx.PrepareContext(ctx, query) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return statement{ |
||||
query: query, |
||||
stmt: stmt, |
||||
}, nil |
||||
} |
||||
|
||||
func (t txSession) QueryRow(v any, query string, args ...any) error { |
||||
return t.QueryRowCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRowCtx(ctx context.Context, v any, q string, args ...any) error { |
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
||||
return unmarshalRow(v, rows, true) |
||||
}, q, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRowPartial(v any, query string, args ...any) error { |
||||
return t.QueryRowPartialCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRowPartialCtx(ctx context.Context, v any, q string, args ...any) error { |
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
||||
return unmarshalRow(v, rows, false) |
||||
}, q, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRows(v any, query string, args ...any) error { |
||||
return t.QueryRowsCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRowsCtx(ctx context.Context, v any, q string, args ...any) error { |
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
||||
return unmarshalRows(v, rows, true) |
||||
}, q, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRowsPartial(v any, query string, args ...any) error { |
||||
return t.QueryRowsPartialCtx(context.Background(), v, query, args...) |
||||
} |
||||
|
||||
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v any, q string, args ...any) error { |
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
||||
return unmarshalRows(v, rows, false) |
||||
}, q, args...) |
||||
} |
||||
func begin(db *sql.DB) (trans, error) { |
||||
tx, err := db.Begin() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return txSession{ |
||||
Tx: tx, |
||||
}, nil |
||||
} |
||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable, |
||||
fn func(context.Context, Session) error) (err error) { |
||||
conn, err := db.connProv() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return transactOnConn(ctx, conn, b, fn) |
||||
} |
||||
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable, |
||||
fn func(context.Context, Session) error) (err error) { |
||||
var tx trans |
||||
tx, err = b(conn) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer func() { |
||||
if p := recover(); p != nil { |
||||
if e := tx.Rollback(); e != nil { |
||||
err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e) |
||||
} else { |
||||
err = fmt.Errorf("recoveer from %#v", p) |
||||
} |
||||
} else if err != nil { |
||||
if e := tx.Rollback(); e != nil { |
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e) |
||||
} |
||||
} else { |
||||
err = tx.Commit() |
||||
} |
||||
}() |
||||
return fn(ctx, tx) |
||||
} |
@ -0,0 +1,147 @@ |
||||
package sqlx |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"git.diulo.com/mogfee/kit/mapping" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
) |
||||
|
||||
var errUnbalancedEscape = errors.New("no char after escape char") |
||||
|
||||
func desensitize(datasource string) string { |
||||
pos := strings.LastIndex(datasource, "@") |
||||
if 0 <= pos && pos+1 < len(datasource) { |
||||
datasource = datasource[pos+1:] |
||||
} |
||||
return datasource |
||||
} |
||||
func escape(input string) string { |
||||
var b strings.Builder |
||||
for _, ch := range input { |
||||
switch ch { |
||||
case '\x00': |
||||
b.WriteString(`\x00`) |
||||
case '\r': |
||||
b.WriteString(`\r`) |
||||
case '\n': |
||||
b.WriteString(`\n`) |
||||
case '\\': |
||||
b.WriteString(`\\`) |
||||
case '\'': |
||||
b.WriteString(`\'`) |
||||
case '"': |
||||
b.WriteString(`"`) |
||||
case '\x1a': |
||||
b.WriteString(`\x1a`) |
||||
default: |
||||
b.WriteRune(ch) |
||||
} |
||||
} |
||||
return b.String() |
||||
} |
||||
func format(query string, args ...any) (string, error) { |
||||
numArgs := len(args) |
||||
if numArgs == 0 { |
||||
return query, nil |
||||
} |
||||
var b strings.Builder |
||||
var argIndex int |
||||
bytes := len(query) |
||||
for i := 0; i < bytes; i++ { |
||||
ch := query[i] |
||||
switch ch { |
||||
case '?': |
||||
if argIndex >= numArgs { |
||||
return "", fmt.Errorf("%d ? in sql, but less arguemnt provided", argIndex) |
||||
} |
||||
writeValue(&b, args[argIndex]) |
||||
argIndex++ |
||||
case ':', '$': |
||||
var j int |
||||
for j = i + 1; j < bytes; j++ { |
||||
char := query[j] |
||||
if char < '0' || '9' < char { |
||||
break |
||||
} |
||||
} |
||||
|
||||
if j > i+1 { |
||||
index, err := strconv.Atoi(query[i+1 : j]) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
if index > argIndex { |
||||
argIndex = index |
||||
} |
||||
index-- |
||||
if index < 0 || numArgs <= index { |
||||
return "", fmt.Errorf("wrong index %d in sql", index) |
||||
} |
||||
writeValue(&b, args[index]) |
||||
i = j - 1 |
||||
} |
||||
case '\'', '"', '`': |
||||
b.WriteByte(ch) |
||||
for j := i + 1; j < bytes; j++ { |
||||
cur := query[j] |
||||
b.WriteByte(cur) |
||||
if cur == '\\' { |
||||
j++ |
||||
if j >= bytes { |
||||
return "", errUnbalancedEscape |
||||
} |
||||
b.WriteByte(query[j]) |
||||
} else if cur == ch { |
||||
i = j |
||||
break |
||||
} |
||||
} |
||||
default: |
||||
b.WriteByte(ch) |
||||
} |
||||
} |
||||
if argIndex < numArgs { |
||||
return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex) |
||||
} |
||||
return b.String(), nil |
||||
} |
||||
|
||||
func logInstanceError(ctx context.Context, datasource string, err error) { |
||||
datasource = desensitize(datasource) |
||||
fmt.Println(datasource, err) |
||||
//logx.WithContext(ctx).Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||
} |
||||
func logSqlError(ctx context.Context, stmt string, err error) { |
||||
if err != nil && err != ErrNotFound { |
||||
fmt.Println(stmt, err) |
||||
//logx.WithContext(ctx).Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||
} |
||||
} |
||||
func writeValue(buf *strings.Builder, arg any) { |
||||
switch v := arg.(type) { |
||||
case bool: |
||||
if v { |
||||
buf.WriteByte('1') |
||||
} else { |
||||
buf.WriteByte('0') |
||||
} |
||||
case string: |
||||
buf.WriteByte('\'') |
||||
buf.WriteString(escape(v)) |
||||
buf.WriteByte('\'') |
||||
case time.Time: |
||||
buf.WriteByte('\'') |
||||
buf.WriteString(v.String()) |
||||
buf.WriteByte('\'') |
||||
case *time.Time: |
||||
buf.WriteByte('\'') |
||||
buf.WriteString(v.String()) |
||||
buf.WriteByte('\'') |
||||
default: |
||||
buf.WriteString(mapping.Repr(v)) |
||||
} |
||||
} |
@ -0,0 +1,17 @@ |
||||
package stringx |
||||
|
||||
import "unicode" |
||||
|
||||
func Ucfirst(str string) string { |
||||
for i, v := range str { |
||||
return string(unicode.ToUpper(v)) + str[i+1:] |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func Lcfirst(str string) string { |
||||
for i, v := range str { |
||||
return string(unicode.ToLower(v)) + str[i+1:] |
||||
} |
||||
return "" |
||||
} |
@ -0,0 +1,26 @@ |
||||
package syncx |
||||
|
||||
import ( |
||||
"sync/atomic" |
||||
"time" |
||||
) |
||||
|
||||
type AtomicDuration int64 |
||||
|
||||
func NewAtomicDuration() *AtomicDuration { |
||||
return new(AtomicDuration) |
||||
} |
||||
func ForAtomicDuration(val time.Duration) *AtomicDuration { |
||||
d := NewAtomicDuration() |
||||
d.Set(val) |
||||
return d |
||||
} |
||||
func (d *AtomicDuration) CompareAndSwap(old, val time.Duration) bool { |
||||
return atomic.CompareAndSwapInt64((*int64)(d), int64(old), int64(val)) |
||||
} |
||||
func (d *AtomicDuration) Load() time.Duration { |
||||
return time.Duration(atomic.LoadInt64((*int64)(d))) |
||||
} |
||||
func (d *AtomicDuration) Set(val time.Duration) { |
||||
atomic.StoreInt64((*int64)(d), int64(val)) |
||||
} |
@ -0,0 +1,60 @@ |
||||
package syncx |
||||
|
||||
import ( |
||||
"git.diulo.com/mogfee/kit/errors" |
||||
"io" |
||||
"sync" |
||||
) |
||||
|
||||
type ResourceManager struct { |
||||
resource map[string]io.Closer |
||||
singleFlight SingleFlight |
||||
lock sync.Mutex |
||||
} |
||||
|
||||
func NewResourceManager() *ResourceManager { |
||||
return &ResourceManager{ |
||||
resource: make(map[string]io.Closer), |
||||
singleFlight: NewSingleFlight(), |
||||
} |
||||
} |
||||
func (manager *ResourceManager) Close() error { |
||||
manager.lock.Lock() |
||||
defer manager.lock.Unlock() |
||||
var be errors.BatchError |
||||
for _, resource := range manager.resource { |
||||
if err := resource.Close(); err != nil { |
||||
be.Add(err) |
||||
} |
||||
} |
||||
manager.resource = nil |
||||
return be.Err() |
||||
} |
||||
func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (io.Closer, error) { |
||||
val, err := manager.singleFlight.Do(key, func() (any, error) { |
||||
manager.lock.Lock() |
||||
resource, ok := manager.resource[key] |
||||
manager.lock.Unlock() |
||||
if ok { |
||||
return resource, nil |
||||
} |
||||
resource, err := create() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
manager.lock.Lock() |
||||
defer manager.lock.Unlock() |
||||
manager.resource[key] = resource |
||||
return resource, nil |
||||
}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return val.(io.Closer), nil |
||||
} |
||||
|
||||
func (manager *ResourceManager) Inject(key string, resource io.Closer) { |
||||
manager.lock.Lock() |
||||
manager.resource[key] = resource |
||||
manager.lock.Unlock() |
||||
} |
@ -0,0 +1,67 @@ |
||||
package syncx |
||||
|
||||
import "sync" |
||||
|
||||
type ( |
||||
SingleFlight interface { |
||||
Do(key string, fn func() (any, error)) (any, error) |
||||
DoEx(key string, fn func() (any, error)) (any, bool, error) |
||||
} |
||||
call struct { |
||||
wg sync.WaitGroup |
||||
val any |
||||
err error |
||||
} |
||||
flightGroup struct { |
||||
calls map[string]*call |
||||
lock sync.Mutex |
||||
} |
||||
) |
||||
|
||||
func NewSingleFlight() SingleFlight { |
||||
return &flightGroup{ |
||||
calls: make(map[string]*call), |
||||
} |
||||
} |
||||
|
||||
func (g *flightGroup) Do(key string, fn func() (any, error)) (any, error) { |
||||
c, done := g.createCall(key) |
||||
if done { |
||||
return c.val, c.err |
||||
} |
||||
g.makeCall(c, key, fn) |
||||
return c.val, c.err |
||||
} |
||||
|
||||
func (g *flightGroup) DoEx(key string, fn func() (any, error)) (v any, fresh bool, err error) { |
||||
c, done := g.createCall(key) |
||||
if done { |
||||
return c.val, false, c.err |
||||
} |
||||
g.makeCall(c, key, fn) |
||||
return c.val, true, c.err |
||||
} |
||||
func (g *flightGroup) createCall(key string) (c *call, done bool) { |
||||
g.lock.Lock() |
||||
//有在执行的等待结果返回
|
||||
if c, ok := g.calls[key]; ok { |
||||
g.lock.Unlock() |
||||
c.wg.Wait() |
||||
return c, true |
||||
} |
||||
//创建
|
||||
c = new(call) |
||||
c.wg.Add(1) |
||||
g.calls[key] = c |
||||
g.lock.Unlock() |
||||
return c, false |
||||
} |
||||
func (g *flightGroup) makeCall(c *call, key string, fn func() (any, error)) { |
||||
defer func() { |
||||
g.lock.Lock() |
||||
delete(g.calls, key) |
||||
g.lock.Unlock() |
||||
c.wg.Done() |
||||
}() |
||||
c.val, c.err = fn() |
||||
} |
@ -0,0 +1,13 @@ |
||||
package timex |
||||
|
||||
import "time" |
||||
|
||||
var initTime = time.Now().AddDate(-1, -1, -1) |
||||
|
||||
func Now() time.Duration { |
||||
return time.Since(initTime) |
||||
} |
||||
|
||||
func Since(d time.Duration) time.Duration { |
||||
return time.Since(initTime) - d |
||||
} |
@ -0,0 +1,64 @@ |
||||
package trace |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"go.opentelemetry.io/otel" |
||||
"go.opentelemetry.io/otel/exporters/jaeger" |
||||
"go.opentelemetry.io/otel/sdk/resource" |
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace" |
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0" |
||||
|
||||
"net/url" |
||||
"sync" |
||||
) |
||||
|
||||
const ( |
||||
kindJaeger = "jaeger" |
||||
) |
||||
|
||||
var ( |
||||
lock sync.Mutex |
||||
tp *sdktrace.TracerProvider |
||||
) |
||||
|
||||
func StartAgent(c Config) { |
||||
lock.Lock() |
||||
defer lock.Unlock() |
||||
if err := startAgent(c); err != nil { |
||||
return |
||||
} |
||||
} |
||||
func StopAgent() { |
||||
_ = tp.Shutdown(context.Background()) |
||||
} |
||||
|
||||
func createExporter(c Config) (sdktrace.SpanExporter, error) { |
||||
//switch c.Batcher {
|
||||
//case kindJaeger:
|
||||
u, _ := url.Parse(c.Endpoint) |
||||
if u.Scheme == "udp" { |
||||
return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Host), jaeger.WithAgentPort(u.Port()))) |
||||
} |
||||
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(u.Host))) |
||||
//}
|
||||
} |
||||
func startAgent(c Config) error { |
||||
opts := []sdktrace.TracerProviderOption{ |
||||
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(c.Sampler))), |
||||
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String(c.Name))), |
||||
} |
||||
if len(c.Endpoint) > 0 { |
||||
exp, err := createExporter(c) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
opts = append(opts, sdktrace.WithBatcher(exp)) |
||||
} |
||||
tp = sdktrace.NewTracerProvider(opts...) |
||||
otel.SetTracerProvider(tp) |
||||
otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) { |
||||
fmt.Println("[otel]", err) |
||||
})) |
||||
return nil |
||||
} |
@ -0,0 +1,24 @@ |
||||
package trace |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"go.opentelemetry.io/otel" |
||||
"testing" |
||||
) |
||||
|
||||
func TestStartAgent(t *testing.T) { |
||||
StartAgent(Config{ |
||||
Name: "my-server", |
||||
Endpoint: "http://127.0.0.1:14268/api/traces", |
||||
Sampler: 1.0, |
||||
Batcher: "jaeger", |
||||
}) |
||||
ctx := context.Background() |
||||
newCtx, span := otel.Tracer("sql_client").Start(ctx, "Run") |
||||
tr := TracerFromContext(newCtx) |
||||
_, span1 := tr.Start(newCtx, "abc") |
||||
span1.End() |
||||
span.End() |
||||
fmt.Println(111) |
||||
} |
@ -0,0 +1,10 @@ |
||||
package trace |
||||
|
||||
const TraceName = "kit" |
||||
|
||||
type Config struct { |
||||
Name string |
||||
Endpoint string |
||||
Sampler float64 |
||||
Batcher string |
||||
} |
@ -0,0 +1,20 @@ |
||||
package trace |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"go.opentelemetry.io/otel" |
||||
"go.opentelemetry.io/otel/trace" |
||||
) |
||||
|
||||
func TracerFromContext(ctx context.Context) (tracer trace.Tracer) { |
||||
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() { |
||||
fmt.Println("find") |
||||
tracer = span.TracerProvider().Tracer(TraceName) |
||||
} else { |
||||
fmt.Println("not find") |
||||
|
||||
tracer = otel.Tracer(TraceName) |
||||
} |
||||
return |
||||
} |
Loading…
Reference in new issue