You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

315 lines
9.1 KiB

package sqlx
import (
	"context"
	"database/sql"
	"errors"
)
const (
	// spanName is the base name for tracing spans in this package
	// (presumably consumed by startSpan, defined elsewhere — confirm).
	spanName = "sql"
)

var (
	// ErrNotFound is the same value as sql.ErrNoRows, so callers can test for
	// a missing row without referring to database/sql directly.
	ErrNotFound = sql.ErrNoRows
)
// Session is the query/exec surface of a database handle. Each operation comes
// in a plain form and a ...Ctx form that additionally takes a context.Context.
// In commonSqlConn the Partial variants differ from their counterparts only in
// passing false instead of true to the row unmarshaler — presumably relaxing
// strict column-to-field matching; confirm in unmarshalRow/unmarshalRows.
type Session interface {
	Exec(query string, args ...any) (sql.Result, error)
	ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
	Prepare(query string) (StmtSession, error)
	PrepareCtx(ctx context.Context, query string) (StmtSession, error)
	QueryRow(v any, query string, args ...any) error
	QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
	QueryRowPartial(v any, query string, args ...any) error
	QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
	QueryRows(v any, query string, args ...any) error
	QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
	QueryRowsPartial(v any, query string, args ...any) error
	QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
}

// SqlConn is a Session that can also expose its underlying *sql.DB and run a
// function through Transact/TransactCtx.
type SqlConn interface {
	Session
	RawDB() (*sql.DB, error)
	Transact(fn func(Session) error) error
	TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
}

// SqlOption mutates a commonSqlConn. NewSqlConn and NewSqlConnFromDB apply
// options in order after installing their defaults.
type SqlOption func(*commonSqlConn)

// StmtSession is a prepared statement bound to its SQL text, so callers pass
// only the arguments. Close releases the statement.
type StmtSession interface {
	Close() error
	Exec(args ...any) (sql.Result, error)
	ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
	QueryRow(v any, args ...any) error
	QueryRowCtx(ctx context.Context, v any, args ...any) error
	QueryRowPartial(v any, args ...any) error
	QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
	QueryRows(v any, args ...any) error
	QueryRowsCtx(ctx context.Context, v any, args ...any) error
	QueryRowsPartial(v any, args ...any) error
	QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
}

// commonSqlConn is the SqlConn implementation returned by the constructors.
type commonSqlConn struct {
	// connProv yields the *sql.DB each operation runs on.
	connProv connProvider
	// onError is an error-reporting hook; the constructors set it to
	// logInstanceError (NewSqlConn) or a no-op (NewSqlConnFromDB).
	onError func(ctx context.Context, err error)
	// accept, when non-nil, lets acceptable tolerate additional errors.
	accept func(error) bool
	// beginTx is handed to transact by TransactCtx.
	beginTx beginnable
}

// connProvider returns the database handle to operate on.
type connProvider func() (db *sql.DB, err error)

// sessionConn is the non-prepared exec/query method set; both *sql.DB and
// *sql.Tx provide these methods.
type sessionConn interface {
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	Query(query string, args ...any) (*sql.Rows, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}

// statement is the StmtSession implementation: a prepared *sql.Stmt together
// with the SQL text it was prepared from (forwarded to execStmt/queryStmt).
type statement struct {
	query string
	stmt  *sql.Stmt
}

// stmtConn is the prepared-statement exec/query method set; *sql.Stmt
// provides these methods.
type stmtConn interface {
	Exec(args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, args ...any) (sql.Result, error)
	Query(args ...any) (*sql.Rows, error)
	QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
// NewSqlConn returns a SqlConn for the given driver and data source. Every use
// of the connection resolves the *sql.DB through getSqlConn, and onError
// forwards errors to logInstanceError tagged with datasource. opts run in
// order after these defaults are installed, so they may override them.
func NewSqlConn(driverName string, datasource string, opts ...SqlOption) SqlConn {
	provide := func() (*sql.DB, error) {
		return getSqlConn(driverName, datasource)
	}
	report := func(ctx context.Context, err error) {
		logInstanceError(ctx, datasource, err)
	}

	conn := &commonSqlConn{connProv: provide, onError: report, beginTx: begin}
	for i := range opts {
		opts[i](conn)
	}
	return conn
}
// NewSqlConnFromDB returns a SqlConn that runs every operation on the
// caller-supplied db. The defaults are a provider that always yields db, a
// no-op onError hook and begin as the transaction starter; opts run in order
// afterwards and may override any of them.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
	conn := &commonSqlConn{
		// The results must stay unnamed: a result variable called db would
		// shadow the parameter, and the provider would always return (nil, nil).
		connProv: func() (*sql.DB, error) {
			return db, nil
		},
		onError: func(ctx context.Context, err error) {
			// Intentionally a no-op; logging was disabled in this port:
			// logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
		},
		beginTx: begin,
	}
	for _, opt := range opts {
		opt(conn)
	}
	return conn
}
// Exec is ExecCtx with a background context.
func (db *commonSqlConn) Exec(query string, args ...any) (sql.Result, error) {
	ctx := context.Background()
	return db.ExecCtx(ctx, query, args...)
}
// ExecCtx runs query with args on the connection returned by connProv, inside
// an "Exec" tracing span that records the final error. If the connection
// cannot be obtained, the error is passed to onError and then returned.
func (db *commonSqlConn) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
	ctx, span := startSpan(ctx, "Exec")
	defer func() {
		endSpan(span, err)
	}()

	conn, err := db.connProv()
	if err != nil {
		// Without this, the onError hook wired by the constructors never fires
		// and connection-acquisition failures go unreported.
		if db.onError != nil {
			db.onError(ctx, err)
		}
		return nil, err
	}

	return exec(ctx, conn, query, args...)
}
// Prepare is PrepareCtx with a background context.
func (db *commonSqlConn) Prepare(query string) (StmtSession, error) {
	ctx := context.Background()
	return db.PrepareCtx(ctx, query)
}
// PrepareCtx prepares query on the connection returned by connProv and wraps
// the result in a StmtSession that remembers the SQL text. The work runs in a
// "Prepare" tracing span. If the connection cannot be obtained, the error is
// passed to onError and then returned.
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
	ctx, span := startSpan(ctx, "Prepare")
	defer func() {
		endSpan(span, err)
	}()

	conn, err := db.connProv()
	if err != nil {
		// Report connection-acquisition failures through the configured hook.
		if db.onError != nil {
			db.onError(ctx, err)
		}
		return nil, err
	}

	st, err := conn.PrepareContext(ctx, query)
	if err != nil {
		return nil, err
	}

	return statement{query: query, stmt: st}, nil
}
// QueryRow is QueryRowCtx with a background context.
func (db *commonSqlConn) QueryRow(v any, query string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowCtx(ctx, v, query, args...)
}
// QueryRowCtx runs query with args and unmarshals the result into v with
// unmarshalRow (flag true), inside a "QueryRow" tracing span.
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRow")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, true)
	}
	return db.queryRows(ctx, scan, query, args...)
}
// QueryRowPartial is QueryRowPartialCtx with a background context.
func (db *commonSqlConn) QueryRowPartial(v any, query string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowPartialCtx(ctx, v, query, args...)
}
// QueryRowPartialCtx runs query with args and unmarshals the result into v
// with unmarshalRow (flag false), inside a "QueryRowPartial" tracing span.
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowPartial")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, false)
	}
	return db.queryRows(ctx, scan, query, args...)
}
// QueryRows is QueryRowsCtx with a background context.
func (db *commonSqlConn) QueryRows(v any, query string, args ...any) error {
	ctx := context.Background()
	return db.QueryRowsCtx(ctx, v, query, args...)
}
// QueryRowsCtx runs query with args and unmarshals all result rows into v with
// unmarshalRows (flag true), inside a "QueryRows" tracing span.
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRows")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, true)
	}
	return db.queryRows(ctx, scan, query, args...)
}
// QueryRowsPartial is QueryRowsPartialCtx with a background context.
func (db *commonSqlConn) QueryRowsPartial(v any, query string, args ...any) error {
	// This previously delegated to QueryRowPartialCtx. That path unmarshals via
	// unmarshalRow instead of unmarshalRows and so did not fill a multi-row
	// destination.
	return db.QueryRowsPartialCtx(context.Background(), v, query, args...)
}
// QueryRowsPartialCtx runs query with args and unmarshals all result rows into
// v with unmarshalRows (flag false), inside a "QueryRowsPartial" tracing span.
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowsPartial")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, false)
	}
	return db.queryRows(ctx, scan, query, args...)
}
// RawDB returns whatever connProv yields, for callers that need database/sql
// functionality this package does not wrap.
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
	raw, err := db.connProv()
	return raw, err
}
// Transact delegates to TransactCtx with a background context. fn does not
// receive that context; use TransactCtx when it needs one.
func (db *commonSqlConn) Transact(fn func(Session) error) error {
	adapter := func(_ context.Context, session Session) error {
		return fn(session)
	}
	return db.TransactCtx(context.Background(), adapter)
}
// TransactCtx runs fn through transact using this connection and its beginTx
// starter. The call is wrapped in a "Transact" tracing span that records the
// final error.
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
	// The span was previously mislabeled "QueryRowsPartial" (copy-paste), which
	// made transactions indistinguishable from queries in traces.
	ctx, span := startSpan(ctx, "Transact")
	defer func() {
		endSpan(span, err)
	}()

	return transact(ctx, db, db.beginTx, fn)
}
// acceptable reports whether err is tolerable. That is true for nil,
// sql.ErrNoRows, sql.ErrTxDone and context.Canceled, including errors that
// wrap any of them (matched with errors.Is). Otherwise the optional accept
// predicate decides; without one the error is not acceptable.
func (db *commonSqlConn) acceptable(err error) bool {
	if err == nil ||
		errors.Is(err, sql.ErrNoRows) ||
		errors.Is(err, sql.ErrTxDone) ||
		errors.Is(err, context.Canceled) {
		return true
	}
	return db.accept != nil && db.accept(err)
}
// queryRows runs q with args on the connection returned by connProv and hands
// the result rows to scanner via query. If the connection cannot be obtained,
// the error is passed to onError and then returned.
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(rows *sql.Rows) error, q string, args ...any) (err error) {
	conn, err := db.connProv()
	if err != nil {
		// Report connection-acquisition failures through the configured hook.
		if db.onError != nil {
			db.onError(ctx, err)
		}
		return err
	}

	// scanner already has the required func(*sql.Rows) error shape, so no
	// wrapper closure is needed.
	return query(ctx, conn, scanner, q, args...)
}
// Close closes the wrapped prepared statement.
func (s statement) Close() error {
	prepared := s.stmt
	return prepared.Close()
}
// Exec is ExecCtx with a background context.
func (s statement) Exec(args ...any) (sql.Result, error) {
	ctx := context.Background()
	return s.ExecCtx(ctx, args...)
}
// ExecCtx executes the prepared statement with args through execStmt, inside
// an "Exec" tracing span that records the final error.
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
	ctx, span := startSpan(ctx, "Exec")
	defer func() {
		endSpan(span, err)
	}()

	result, err = execStmt(ctx, s.stmt, s.query, args...)
	return result, err
}
// QueryRow is QueryRowCtx with a background context.
func (s statement) QueryRow(v any, args ...any) error {
	ctx := context.Background()
	return s.QueryRowCtx(ctx, v, args...)
}
// QueryRowCtx queries the prepared statement with args and unmarshals the
// result into v with unmarshalRow (flag true), inside a "QueryRow" span.
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRow")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, true)
	}
	err = queryStmt(ctx, s.stmt, scan, s.query, args...)
	return err
}
// QueryRowPartial is QueryRowPartialCtx with a background context.
func (s statement) QueryRowPartial(v any, args ...any) error {
	ctx := context.Background()
	return s.QueryRowPartialCtx(ctx, v, args...)
}
// QueryRowPartialCtx queries the prepared statement with args and unmarshals
// the result into v with unmarshalRow (flag false), inside a
// "QueryRowPartial" span.
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowPartial")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRow(v, rows, false)
	}
	err = queryStmt(ctx, s.stmt, scan, s.query, args...)
	return err
}
// QueryRows is QueryRowsCtx with a background context.
func (s statement) QueryRows(v any, args ...any) error {
	ctx := context.Background()
	return s.QueryRowsCtx(ctx, v, args...)
}
// QueryRowsCtx queries the prepared statement with args and unmarshals all
// result rows into v with unmarshalRows (flag true), inside a "QueryRows" span.
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRows")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, true)
	}
	err = queryStmt(ctx, s.stmt, scan, s.query, args...)
	return err
}
// QueryRowsPartial is QueryRowsPartialCtx with a background context.
func (s statement) QueryRowsPartial(v any, args ...any) error {
	ctx := context.Background()
	return s.QueryRowsPartialCtx(ctx, v, args...)
}
// QueryRowsPartialCtx queries the prepared statement with args and unmarshals
// all result rows into v with unmarshalRows (flag false), inside a
// "QueryRowsPartial" span.
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
	ctx, span := startSpan(ctx, "QueryRowsPartial")
	defer func() {
		endSpan(span, err)
	}()

	scan := func(rows *sql.Rows) error {
		return unmarshalRows(v, rows, false)
	}
	err = queryStmt(ctx, s.stmt, scan, s.query, args...)
	return err
}