package sqlx import ( "context" "database/sql" ) const spanName = "sql" var ErrNotFound = sql.ErrNoRows type ( Session interface { Exec(query string, args ...any) (sql.Result, error) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) Prepare(query string) (StmtSession, error) PrepareCtx(ctx context.Context, query string) (StmtSession, error) QueryRow(v any, query string, args ...any) error QueryRowCtx(ctx context.Context, v any, query string, args ...any) error QueryRowPartial(v any, query string, args ...any) error QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error QueryRows(v any, query string, args ...any) error QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error QueryRowsPartial(v any, query string, args ...any) error QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error } SqlConn interface { Session RawDB() (*sql.DB, error) Transact(fn func(Session) error) error TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error } SqlOption func(*commonSqlConn) StmtSession interface { Close() error Exec(args ...any) (sql.Result, error) ExecCtx(ctx context.Context, args ...any) (sql.Result, error) QueryRow(v any, args ...any) error QueryRowCtx(ctx context.Context, v any, args ...any) error QueryRowPartial(v any, args ...any) error QueryRowPartialCtx(ctx context.Context, v any, args ...any) error QueryRows(v any, args ...any) error QueryRowsCtx(ctx context.Context, v any, args ...any) error QueryRowsPartial(v any, args ...any) error QueryRowsPartialCtx(ctx context.Context, v any, args ...any) error } commonSqlConn struct { connProv connProvider onError func(ctx context.Context, err error) accept func(error) bool beginTx beginnable } connProvider func() (db *sql.DB, err error) sessionConn interface { Exec(query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) Query(query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } statement struct { query string stmt *sql.Stmt } stmtConn interface { Exec(args ...any) (sql.Result, error) ExecContext(ctx context.Context, args ...any) (sql.Result, error) Query(args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) } ) func NewSqlConn(driverName string, datasource string, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ connProv: func() (db *sql.DB, err error) { return getSqlConn(driverName, datasource) }, onError: func(ctx context.Context, err error) { logInstanceError(ctx, datasource, err) }, beginTx: begin, } for _, opt := range opts { opt(conn) } return conn } func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ connProv: func() (db *sql.DB, err error) { return db, nil }, onError: func(ctx context.Context, err error) { //logx.WithContext(ctx).Errorf("Error on getting sql instance: %v", err) }, beginTx: begin, } for _, opt := range opts { opt(conn) } return conn } func (db *commonSqlConn) Exec(query string, args ...any) (sql.Result, error) { return db.ExecCtx(context.Background(), query, args...) } func (db *commonSqlConn) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) { ctx, span := startSpan(ctx, "Exec") defer func() { endSpan(span, err) }() conn, err := db.connProv() if err != nil { return nil, err } result, err = exec(ctx, conn, query, args...) return result, err } func (db *commonSqlConn) Prepare(query string) (StmtSession, error) { return db.PrepareCtx(context.Background(), query) } func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) { ctx, span := startSpan(ctx, "Prepare") defer func() { endSpan(span, err) }() var conn *sql.DB conn, err = db.connProv() if err != nil { return nil, err } var st *sql.Stmt st, err = conn.PrepareContext(ctx, query) if err != nil { return nil, err } stmt = statement{ stmt: st, query: query, } return } func (db *commonSqlConn) QueryRow(v any, query string, args ...any) error { return db.QueryRowCtx(context.Background(), v, query, args...) } func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v any, query string, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRow") defer func() { endSpan(span, err) }() err = db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, query, args...) return } func (db *commonSqlConn) QueryRowPartial(v any, query string, args ...any) error { return db.QueryRowPartialCtx(context.Background(), v, query, args...) } func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRowPartial") defer func() { endSpan(span, err) }() err = db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, query, args...) return } func (db *commonSqlConn) QueryRows(v any, query string, args ...any) error { return db.QueryRowsCtx(context.Background(), v, query, args...) } func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRows") defer func() { endSpan(span, err) }() err = db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, query, args...) return err } func (db *commonSqlConn) QueryRowsPartial(v any, query string, args ...any) error { return db.QueryRowPartialCtx(context.Background(), v, query, args...) } func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRowsPartial") defer func() { endSpan(span, err) }() err = db.queryRows(ctx, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, query, args...) return } func (db *commonSqlConn) RawDB() (*sql.DB, error) { return db.connProv() } func (db *commonSqlConn) Transact(fn func(Session) error) error { return db.TransactCtx(context.Background(), func(ctx context.Context, session Session) error { return fn(session) }) } func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) { ctx, span := startSpan(ctx, "QueryRowsPartial") defer func() { endSpan(span, err) }() err = transact(ctx, db, db.beginTx, fn) return } func (db *commonSqlConn) acceptable(err error) bool { ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled if db.accept == nil { return ok } return ok || db.accept(err) } func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(rows *sql.Rows) error, q string, args ...any) (err error) { conn, err := db.connProv() if err != nil { return err } return query(ctx, conn, func(rows *sql.Rows) error { return scanner(rows) }, q, args...) } func (s statement) Close() error { return s.stmt.Close() } func (s statement) Exec(args ...any) (sql.Result, error) { return s.ExecCtx(context.Background(), args...) } func (s statement) ExecCtx(ctx context.Context, args ...any) (result sql.Result, err error) { ctx, span := startSpan(ctx, "Exec") defer func() { endSpan(span, err) }() return execStmt(ctx, s.stmt, s.query, args...) } func (s statement) QueryRow(v any, args ...any) error { return s.QueryRowCtx(context.Background(), v, args...) } func (s statement) QueryRowCtx(ctx context.Context, v any, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRow") defer func() { endSpan(span, err) }() return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, true) }, s.query, args...) } func (s statement) QueryRowPartial(v any, args ...any) error { return s.QueryRowPartialCtx(context.Background(), v, args...) } func (s statement) QueryRowPartialCtx(ctx context.Context, v any, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRowPartial") defer func() { endSpan(span, err) }() return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRow(v, rows, false) }, s.query, args...) } func (s statement) QueryRows(v any, args ...any) error { return s.QueryRowsCtx(context.Background(), v, args...) } func (s statement) QueryRowsCtx(ctx context.Context, v any, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRows") defer func() { endSpan(span, err) }() return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, true) }, s.query, args...) } func (s statement) QueryRowsPartial(v any, args ...any) error { return s.QueryRowsPartialCtx(context.Background(), v, args...) } func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any) (err error) { ctx, span := startSpan(ctx, "QueryRowsPartial") defer func() { endSpan(span, err) }() return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error { return unmarshalRows(v, rows, false) }, s.query, args...) }