You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
316 lines
9.1 KiB
316 lines
9.1 KiB
1 year ago
|
package sqlx
|
||
|
|
||
|
import (
	"context"
	"database/sql"
	"errors"
)
|
||
|
|
||
|
const (
	// spanName is this package's tracing identifier.
	// NOTE(review): not referenced in this section; presumably consumed by
	// startSpan, which is defined elsewhere — confirm.
	spanName = "sql"
)

var (
	// ErrNotFound is the very same value as sql.ErrNoRows, so callers can
	// check for a missing row without importing database/sql.
	ErrNotFound = sql.ErrNoRows
)
|
||
|
|
||
|
type (
|
||
|
Session interface {
|
||
|
Exec(query string, args ...any) (sql.Result, error)
|
||
|
ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||
|
Prepare(query string) (StmtSession, error)
|
||
|
PrepareCtx(ctx context.Context, query string) (StmtSession, error)
|
||
|
QueryRow(v any, query string, args ...any) error
|
||
|
QueryRowCtx(ctx context.Context, v any, query string, args ...any) error
|
||
|
QueryRowPartial(v any, query string, args ...any) error
|
||
|
QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error
|
||
|
QueryRows(v any, query string, args ...any) error
|
||
|
QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error
|
||
|
QueryRowsPartial(v any, query string, args ...any) error
|
||
|
QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error
|
||
|
}
|
||
|
SqlConn interface {
|
||
|
Session
|
||
|
RawDB() (*sql.DB, error)
|
||
|
Transact(fn func(Session) error) error
|
||
|
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
|
||
|
}
|
||
|
SqlOption func(*commonSqlConn)
|
||
|
StmtSession interface {
|
||
|
Close() error
|
||
|
Exec(args ...any) (sql.Result, error)
|
||
|
ExecCtx(ctx context.Context, args ...any) (sql.Result, error)
|
||
|
QueryRow(v any, args ...any) error
|
||
|
QueryRowCtx(ctx context.Context, v any, args ...any) error
|
||
|
QueryRowPartial(v any, args ...any) error
|
||
|
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error
|
||
|
QueryRows(v any, args ...any) error
|
||
|
QueryRowsCtx(ctx context.Context, v any, args ...any) error
|
||
|
QueryRowsPartial(v any, args ...any) error
|
||
|
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error
|
||
|
}
|
||
|
|
||
|
commonSqlConn struct {
|
||
|
connProv connProvider
|
||
|
onError func(ctx context.Context, err error)
|
||
|
accept func(error) bool
|
||
|
beginTx beginnable
|
||
|
}
|
||
|
connProvider func() (db *sql.DB, err error)
|
||
|
|
||
|
sessionConn interface {
|
||
|
Exec(query string, args ...any) (sql.Result, error)
|
||
|
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||
|
Query(query string, args ...any) (*sql.Rows, error)
|
||
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||
|
}
|
||
|
statement struct {
|
||
|
query string
|
||
|
stmt *sql.Stmt
|
||
|
}
|
||
|
stmtConn interface {
|
||
|
Exec(args ...any) (sql.Result, error)
|
||
|
ExecContext(ctx context.Context, args ...any) (sql.Result, error)
|
||
|
Query(args ...any) (*sql.Rows, error)
|
||
|
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||
|
}
|
||
|
)
|
||
|
|
||
|
func NewSqlConn(driverName string, datasource string, opts ...SqlOption) SqlConn {
|
||
|
conn := &commonSqlConn{
|
||
|
connProv: func() (db *sql.DB, err error) {
|
||
|
return getSqlConn(driverName, datasource)
|
||
|
},
|
||
|
onError: func(ctx context.Context, err error) {
|
||
|
logInstanceError(ctx, datasource, err)
|
||
|
},
|
||
|
beginTx: begin,
|
||
|
}
|
||
|
for _, opt := range opts {
|
||
|
opt(conn)
|
||
|
}
|
||
|
return conn
|
||
|
}
|
||
|
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||
|
conn := &commonSqlConn{
|
||
|
connProv: func() (db *sql.DB, err error) {
|
||
|
return db, nil
|
||
|
},
|
||
|
onError: func(ctx context.Context, err error) {
|
||
|
//logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err)
|
||
|
},
|
||
|
beginTx: begin,
|
||
|
}
|
||
|
for _, opt := range opts {
|
||
|
opt(conn)
|
||
|
}
|
||
|
return conn
|
||
|
}
|
||
|
func (db *commonSqlConn) Exec(query string, args ...any) (sql.Result, error) {
|
||
|
return db.ExecCtx(context.Background(), query, args...)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
|
||
|
ctx, span := startSpan(ctx, "Exec")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
conn, err := db.connProv()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
result, err = exec(ctx, conn, query, args...)
|
||
|
return result, err
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) Prepare(query string) (StmtSession, error) {
|
||
|
return db.PrepareCtx(context.Background(), query)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
|
||
|
ctx, span := startSpan(ctx, "Prepare")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
var conn *sql.DB
|
||
|
conn, err = db.connProv()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
var st *sql.Stmt
|
||
|
st, err = conn.PrepareContext(ctx, query)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
stmt = statement{
|
||
|
stmt: st,
|
||
|
query: query,
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRow(v any, query string, args ...any) error {
|
||
|
return db.QueryRowCtx(context.Background(), v, query, args...)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRow")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
|
||
|
err = db.queryRows(ctx, func(rows *sql.Rows) error {
|
||
|
return unmarshalRow(v, rows, true)
|
||
|
}, query, args...)
|
||
|
return
|
||
|
}
|
||
|
func (db *commonSqlConn) QueryRowPartial(v any, query string, args ...any) error {
|
||
|
return db.QueryRowPartialCtx(context.Background(), v, query, args...)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRowPartial")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
err = db.queryRows(ctx, func(rows *sql.Rows) error {
|
||
|
return unmarshalRow(v, rows, false)
|
||
|
}, query, args...)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRows(v any, query string, args ...any) error {
|
||
|
return db.QueryRowsCtx(context.Background(), v, query, args...)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRows")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
err = db.queryRows(ctx, func(rows *sql.Rows) error {
|
||
|
return unmarshalRows(v, rows, true)
|
||
|
}, query, args...)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRowsPartial(v any, query string, args ...any) error {
|
||
|
return db.QueryRowPartialCtx(context.Background(), v, query, args...)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRowsPartial")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
err = db.queryRows(ctx, func(rows *sql.Rows) error {
|
||
|
return unmarshalRows(v, rows, false)
|
||
|
}, query, args...)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||
|
return db.connProv()
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||
|
return db.TransactCtx(context.Background(), func(ctx context.Context, session Session) error {
|
||
|
return fn(session)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRowsPartial")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
|
||
|
err = transact(ctx, db, db.beginTx, fn)
|
||
|
return
|
||
|
}
|
||
|
func (db *commonSqlConn) acceptable(err error) bool {
|
||
|
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
|
||
|
if db.accept == nil {
|
||
|
return ok
|
||
|
}
|
||
|
return ok || db.accept(err)
|
||
|
}
|
||
|
|
||
|
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(rows *sql.Rows) error, q string, args ...any) (err error) {
|
||
|
conn, err := db.connProv()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return query(ctx, conn, func(rows *sql.Rows) error {
|
||
|
return scanner(rows)
|
||
|
}, q, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) Close() error {
|
||
|
return s.stmt.Close()
|
||
|
}
|
||
|
|
||
|
func (s statement) Exec(args ...any) (sql.Result, error) {
|
||
|
return s.ExecCtx(context.Background(), args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) {
|
||
|
ctx, span := startSpan(ctx, "Exec")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
return execStmt(ctx, s.stmt, s.query, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRow(v any, args ...any) error {
|
||
|
return s.QueryRowCtx(context.Background(), v, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRow")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||
|
return unmarshalRow(v, rows, true)
|
||
|
}, s.query, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRowPartial(v any, args ...any) error {
|
||
|
return s.QueryRowPartialCtx(context.Background(), v, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRowPartial")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||
|
return unmarshalRow(v, rows, false)
|
||
|
}, s.query, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRows(v any, args ...any) error {
|
||
|
return s.QueryRowsCtx(context.Background(), v, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRows")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||
|
return unmarshalRows(v, rows, true)
|
||
|
}, s.query, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRowsPartial(v any, args ...any) error {
|
||
|
return s.QueryRowsPartialCtx(context.Background(), v, args...)
|
||
|
}
|
||
|
|
||
|
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) {
|
||
|
ctx, span := startSpan(ctx, "QueryRowsPartial")
|
||
|
defer func() {
|
||
|
endSpan(span, err)
|
||
|
}()
|
||
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||
|
return unmarshalRows(v, rows, false)
|
||
|
}, s.query, args...)
|
||
|
}
|