You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
3.2 KiB
128 lines
3.2 KiB
package sqlx |
|
|
|
import ( |
|
"context" |
|
"database/sql" |
|
"fmt" |
|
) |
|
|
|
type ( |
|
beginnable func(*sql.DB) (trans, error) |
|
trans interface { |
|
Session |
|
Commit() error |
|
Rollback() error |
|
} |
|
txSession struct { |
|
*sql.Tx |
|
} |
|
) |
|
|
|
func NewSession(tx *sql.Tx) Session { |
|
return txSession{Tx: tx} |
|
} |
|
|
|
func (t txSession) Exec(query string, args ...any) (sql.Result, error) { |
|
return t.ExecContext(context.Background(), query, args...) |
|
} |
|
|
|
func (t txSession) ExecCtx(ctx context.Context, query string, args ...any) (result sql.Result, err error) { |
|
result, err = exec(ctx, t.Tx, query, args...) |
|
return |
|
} |
|
|
|
func (t txSession) Prepare(query string) (StmtSession, error) { |
|
return t.PrepareCtx(context.Background(), query) |
|
} |
|
|
|
func (t txSession) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { |
|
stmt, err := t.Tx.PrepareContext(ctx, query) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return statement{ |
|
query: query, |
|
stmt: stmt, |
|
}, nil |
|
} |
|
|
|
func (t txSession) QueryRow(v any, query string, args ...any) error { |
|
return t.QueryRowCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (t txSession) QueryRowCtx(ctx context.Context, v any, q string, args ...any) error { |
|
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
|
return unmarshalRow(v, rows, true) |
|
}, q, args...) |
|
} |
|
|
|
func (t txSession) QueryRowPartial(v any, query string, args ...any) error { |
|
return t.QueryRowPartialCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (t txSession) QueryRowPartialCtx(ctx context.Context, v any, q string, args ...any) error { |
|
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
|
return unmarshalRow(v, rows, false) |
|
}, q, args...) |
|
} |
|
|
|
func (t txSession) QueryRows(v any, query string, args ...any) error { |
|
return t.QueryRowsCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (t txSession) QueryRowsCtx(ctx context.Context, v any, q string, args ...any) error { |
|
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
|
return unmarshalRows(v, rows, true) |
|
}, q, args...) |
|
} |
|
|
|
func (t txSession) QueryRowsPartial(v any, query string, args ...any) error { |
|
return t.QueryRowsPartialCtx(context.Background(), v, query, args...) |
|
} |
|
|
|
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v any, q string, args ...any) error { |
|
return query(ctx, t.Tx, func(rows *sql.Rows) error { |
|
return unmarshalRows(v, rows, false) |
|
}, q, args...) |
|
} |
|
func begin(db *sql.DB) (trans, error) { |
|
tx, err := db.Begin() |
|
if err != nil { |
|
return nil, err |
|
} |
|
return txSession{ |
|
Tx: tx, |
|
}, nil |
|
} |
|
func transact(ctx context.Context, db *commonSqlConn, b beginnable, |
|
fn func(context.Context, Session) error) (err error) { |
|
conn, err := db.connProv() |
|
if err != nil { |
|
return err |
|
} |
|
return transactOnConn(ctx, conn, b, fn) |
|
} |
|
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable, |
|
fn func(context.Context, Session) error) (err error) { |
|
var tx trans |
|
tx, err = b(conn) |
|
if err != nil { |
|
return err |
|
} |
|
defer func() { |
|
if p := recover(); p != nil { |
|
if e := tx.Rollback(); e != nil { |
|
err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e) |
|
} else { |
|
err = fmt.Errorf("recoveer from %#v", p) |
|
} |
|
} else if err != nil { |
|
if e := tx.Rollback(); e != nil { |
|
err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e) |
|
} |
|
} else { |
|
err = tx.Commit() |
|
} |
|
}() |
|
return fn(ctx, tx) |
|
}
|
|
|