|
|
|
package cors
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"git.diulo.com/mogfee/kit/middleware"
|
|
|
|
"git.diulo.com/mogfee/kit/transport"
|
|
|
|
"net/http"
|
|
|
|
)
|
|
|
|
|
|
|
|
// OptionFunc mutates the package's private CORS option set. Values of this
// type are produced by the With* constructors and applied, in order, by Cors.
type OptionFunc func(o *option)
|
|
|
|
|
|
|
|
func WithDomain(domain string) OptionFunc {
|
|
|
|
return func(o *option) {
|
|
|
|
o.domain = domain
|
|
|
|
}
|
|
|
|
}
|
|
|
|
func WithMethod(method string) OptionFunc {
|
|
|
|
return func(o *option) {
|
|
|
|
o.method = method
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func WithHeader(header string) OptionFunc {
|
|
|
|
return func(o *option) {
|
|
|
|
o.headers = header
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// option holds the values Cors writes into the CORS response headers. Each
// field is stored as the final header value; no parsing or joining is done.
type option struct {
	// domain is sent as Access-Control-Allow-Origin.
	domain string
	// method is sent as Access-Control-Allow-Methods.
	method string
	// headers is sent as Access-Control-Allow-Headers.
	headers string
}
|
|
|
|
|
|
|
|
func Cors(ops ...OptionFunc) middleware.Middleware {
|
|
|
|
cfg := &option{
|
|
|
|
domain: "*",
|
|
|
|
method: "GET,POST,OPTIONS,PUT,PATCH,DELETE",
|
|
|
|
headers: "Content-Type,X-Requested-With,Access-Control-Allow-Credentials,User-Agent,Content-Length,Authorization",
|
|
|
|
}
|
|
|
|
for _, o := range ops {
|
|
|
|
o(cfg)
|
|
|
|
}
|
|
|
|
return func(handler middleware.Handler) middleware.Handler {
|
|
|
|
return func(ctx context.Context, a any) (any, error) {
|
|
|
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
|
|
|
header := tr.ReplyHeader()
|
|
|
|
header.Set("Access-Control-Allow-Origin", cfg.domain)
|
|
|
|
header.Set("Access-Control-Allow-Methods", cfg.method)
|
|
|
|
header.Set("Access-Control-Allow-Credentials", "true")
|
|
|
|
header.Set("Access-Control-Allow-Headers", cfg.headers)
|
|
|
|
}
|
|
|
|
return handler(ctx, a)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func HttpServer() func(http.Handler) http.Handler {
|
|
|
|
return func(h http.Handler) http.Handler {
|
|
|
|
ch := &cors{h: h}
|
|
|
|
ch.h = h
|
|
|
|
return ch
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// cors is an http.Handler that wraps another handler and brackets each
// request with start/end trace lines on standard output.
type cors struct {
	// h is the wrapped handler every request is forwarded to.
	h http.Handler
}

// Compile-time check that *cors satisfies http.Handler.
var _ http.Handler = (*cors)(nil)

// ServeHTTP prints "cors start", forwards the request unchanged to the
// wrapped handler, then prints "cors end" once that handler returns.
//
// NOTE(review): despite the type name, no CORS headers are written and
// preflight (OPTIONS) requests get no special handling here — confirm whether
// that logic is still to be added.
func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	fmt.Println("cors start")
	ch.h.ServeHTTP(w, r)
	fmt.Println("cors end")
}
|