You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
315 lines
9.1 KiB
315 lines
9.1 KiB
package sqlx |
|
|
|
import (
	"context"
	"database/sql"
	"errors"
)
|
|
|
// spanName labels SQL spans; presumably consumed by startSpan, which is
// defined elsewhere in the package — confirm.
const spanName = "sql"

// ErrNotFound is the very same value as sql.ErrNoRows, so callers can test
// for "no rows" without importing database/sql themselves.
var ErrNotFound error = sql.ErrNoRows
|
|
|
type ( |
|
Session interface { |
|
Exec(query string, args ...any) (sql.Result, error) |
|
ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) |
|
Prepare(query string) (StmtSession, error) |
|
PrepareCtx(ctx context.Context, query string) (StmtSession, error) |
|
QueryRow(v any, query string, args ...any) error |
|
QueryRowCtx(ctx context.Context, v any, query string, args ...any) error |
|
QueryRowPartial(v any, query string, args ...any) error |
|
QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error |
|
QueryRows(v any, query string, args ...any) error |
|
QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error |
|
QueryRowsPartial(v any, query string, args ...any) error |
|
QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error |
|
} |
|
SqlConn interface { |
|
Session |
|
RawDB() (*sql.DB, error) |
|
Transact(fn func(Session) error) error |
|
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error |
|
} |
|
SqlOption func(*commonSqlConn) |
|
StmtSession interface { |
|
Close() error |
|
Exec(args ...any) (sql.Result, error) |
|
ExecCtx(ctx context.Context, args ...any) (sql.Result, error) |
|
QueryRow(v any, args ...any) error |
|
QueryRowCtx(ctx context.Context, v any, args ...any) error |
|
QueryRowPartial(v any, args ...any) error |
|
QueryRowPartialCtx(ctx context.Context, v any, args ...any) error |
|
QueryRows(v any, args ...any) error |
|
QueryRowsCtx(ctx context.Context, v any, args ...any) error |
|
QueryRowsPartial(v any, args ...any) error |
|
QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error |
|
} |
|
|
|
commonSqlConn struct { |
|
connProv connProvider |
|
onError func(ctx context.Context, err error) |
|
accept func(error) bool |
|
beginTx beginnable |
|
} |
|
connProvider func() (db *sql.DB, err error) |
|
|
|
sessionConn interface { |
|
Exec(query string, args ...any) (sql.Result, error) |
|
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) |
|
Query(query string, args ...any) (*sql.Rows, error) |
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) |
|
} |
|
statement struct { |
|
query string |
|
stmt *sql.Stmt |
|
} |
|
stmtConn interface { |
|
Exec(args ...any) (sql.Result, error) |
|
ExecContext(ctx context.Context, args ...any) (sql.Result, error) |
|
Query(args ...any) (*sql.Rows, error) |
|
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) |
|
} |
|
) |
|
|
|
func NewSqlConn(driverName string, datasource string, opts ...SqlOption) SqlConn { |
|
conn := &commonSqlConn{ |
|
connProv: func() (db *sql.DB, err error) { |
|
return getSqlConn(driverName, datasource) |
|
}, |
|
onError: func(ctx context.Context, err error) { |
|
logInstanceError(ctx, datasource, err) |
|
}, |
|
beginTx: begin, |
|
} |
|
for _, opt := range opts { |
|
opt(conn) |
|
} |
|
return conn |
|
} |
|
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { |
|
conn := &commonSqlConn{ |
|
connProv: func() (db *sql.DB, err error) { |
|
return db, nil |
|
}, |
|
onError: func(ctx context.Context, err error) { |
|
//logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err) |
|
}, |
|
beginTx: begin, |
|
} |
|
for _, opt := range opts { |
|
opt(conn) |
|
} |
|
return conn |
|
} |
|
func (db *commonSqlConn) Exec(query string, args ...any) (sql.Result, error) { |
|
return db.ExecCtx(context.Background(), query, args...) |
|
} |
|
|
|
func (db *commonSqlConn) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) { |
|
ctx, span := startSpan(ctx, "Exec") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
conn, err := db.connProv() |
|
if err != nil { |
|
return nil, err |
|
} |
|
result, err = exec(ctx, conn, query, args...) |
|
return result, err |
|
} |
|
|
|
func (db *commonSqlConn) Prepare(query string) (StmtSession, error) { |
|
return db.PrepareCtx(context.Background(), query) |
|
} |
|
|
|
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) { |
|
ctx, span := startSpan(ctx, "Prepare") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
var conn *sql.DB |
|
conn, err = db.connProv() |
|
if err != nil { |
|
return nil, err |
|
} |
|
var st *sql.Stmt |
|
st, err = conn.PrepareContext(ctx, query) |
|
if err != nil { |
|
return nil, err |
|
} |
|
stmt = statement{ |
|
stmt: st, |
|
query: query, |
|
} |
|
|
|
return |
|
} |
|
|
|
func (db *commonSqlConn) QueryRow(v any, query string, args ...any) error { |
|
return db.QueryRowCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRow") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
|
|
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
|
return unmarshalRow(v, rows, true) |
|
}, query, args...) |
|
return |
|
} |
|
func (db *commonSqlConn) QueryRowPartial(v any, query string, args ...any) error { |
|
return db.QueryRowPartialCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRowPartial") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
|
return unmarshalRow(v, rows, false) |
|
}, query, args...) |
|
return |
|
} |
|
|
|
func (db *commonSqlConn) QueryRows(v any, query string, args ...any) error { |
|
return db.QueryRowsCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRows") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
|
return unmarshalRows(v, rows, true) |
|
}, query, args...) |
|
return err |
|
} |
|
|
|
func (db *commonSqlConn) QueryRowsPartial(v any, query string, args ...any) error { |
|
return db.QueryRowPartialCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRowsPartial") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
err = db.queryRows(ctx, func(rows *sql.Rows) error { |
|
return unmarshalRows(v, rows, false) |
|
}, query, args...) |
|
return |
|
} |
|
|
|
func (db *commonSqlConn) RawDB() (*sql.DB, error) { |
|
return db.connProv() |
|
} |
|
|
|
func (db *commonSqlConn) Transact(fn func(Session) error) error { |
|
return db.TransactCtx(context.Background(), func(ctx context.Context, session Session) error { |
|
return fn(session) |
|
}) |
|
} |
|
|
|
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRowsPartial") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
|
|
err = transact(ctx, db, db.beginTx, fn) |
|
return |
|
} |
|
func (db *commonSqlConn) acceptable(err error) bool { |
|
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled |
|
if db.accept == nil { |
|
return ok |
|
} |
|
return ok || db.accept(err) |
|
} |
|
|
|
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(rows *sql.Rows) error, q string, args ...any) (err error) { |
|
conn, err := db.connProv() |
|
if err != nil { |
|
return err |
|
} |
|
return query(ctx, conn, func(rows *sql.Rows) error { |
|
return scanner(rows) |
|
}, q, args...) |
|
} |
|
|
|
func (s statement) Close() error { |
|
return s.stmt.Close() |
|
} |
|
|
|
func (s statement) Exec(args ...any) (sql.Result, error) { |
|
return s.ExecCtx(context.Background(), args...) |
|
} |
|
|
|
func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) { |
|
ctx, span := startSpan(ctx, "Exec") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
return execStmt(ctx, s.stmt, s.query, args...) |
|
} |
|
|
|
func (s statement) QueryRow(v any, args ...any) error { |
|
return s.QueryRowCtx(context.Background(), v, args...) |
|
} |
|
|
|
func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRow") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
|
return unmarshalRow(v, rows, true) |
|
}, s.query, args...) |
|
} |
|
|
|
func (s statement) QueryRowPartial(v any, args ...any) error { |
|
return s.QueryRowPartialCtx(context.Background(), v, args...) |
|
} |
|
|
|
func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRowPartial") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
|
return unmarshalRow(v, rows, false) |
|
}, s.query, args...) |
|
} |
|
|
|
func (s statement) QueryRows(v any, args ...any) error { |
|
return s.QueryRowsCtx(context.Background(), v, args...) |
|
} |
|
|
|
func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRows") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
|
return unmarshalRows(v, rows, true) |
|
}, s.query, args...) |
|
} |
|
|
|
func (s statement) QueryRowsPartial(v any, args ...any) error { |
|
return s.QueryRowsPartialCtx(context.Background(), v, args...) |
|
} |
|
|
|
func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) { |
|
ctx, span := startSpan(ctx, "QueryRowsPartial") |
|
defer func() { |
|
endSpan(span, err) |
|
}() |
|
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { |
|
return unmarshalRows(v, rows, false) |
|
}, s.query, args...) |
|
}
|
|
|