diff --git a/example/main.go b/example/main.go index 0901450..8dd4a8a 100644 --- a/example/main.go +++ b/example/main.go @@ -5,8 +5,6 @@ import ( "flag" "fmt" "git.diulo.com/mogfee/kit" - "git.diulo.com/mogfee/kit/middleware" - "git.diulo.com/mogfee/kit/transport" "git.diulo.com/mogfee/kit/transport/http" http2 "net/http" ) @@ -49,19 +47,7 @@ func runApp(host string) { // h: handler, // } //}), - http.Middleware(func(handler middleware.Handler) middleware.Handler { - return func(ctx context.Context, a any) (any, error) { - if tr, ok := transport.FromServerContext(ctx); ok { - header := tr.ReplyHeader() - header.Set("Access-Control-Allow-Origin", "*") - header.Set("Access-Control-Allow-Methods", "GET,POST,OPTIONS,PUT,PATCH,DELETE") - header.Set("Access-Control-Allow-Credentials", "true") - header.Set("Access-Control-Allow-Headers", "Content-Type,X-Requested-With,Access-Control-Allow-Credentials,User-Agent,Content-Length,Authorization") - } - - return handler(ctx, a) - } - }), + http.Middleware(), ) route := hs.Route("/") route.GET("/api/v1/answer/listCategory", func(ctx http.Context) error { diff --git a/middleware/cros/cros.go b/middleware/cros/cros.go new file mode 100644 index 0000000..8327a41 --- /dev/null +++ b/middleware/cros/cros.go @@ -0,0 +1,55 @@ +package cros + +import ( + "context" + "git.diulo.com/mogfee/kit/middleware" + "git.diulo.com/mogfee/kit/transport" +) + +type OptionFunc func(o *option) + +func WithDomain(domain string) OptionFunc { + return func(o *option) { + o.domain = domain + } +} +func WithMethod(method string) OptionFunc { + return func(o *option) { + o.method = method + } +} + +func WithHeader(header string) OptionFunc { + return func(o *option) { + o.headers = header + } +} + +type option struct { + domain string + method string + headers string +} + +func Cors(ops ...OptionFunc) middleware.Middleware { + cfg := &option{ + domain: "*", + method: "GET,POST,OPTIONS,PUT,PATCH,DELETE", + headers: "Content-Type,X-Requested-With,Access-Control-Allow-Credentials,User-Agent,Content-Length,Authorization", + } + for _, o := range ops { + o(cfg) + } + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, a any) (any, error) { + if tr, ok := transport.FromServerContext(ctx); ok { + header := tr.ReplyHeader() + header.Set("Access-Control-Allow-Origin", cfg.domain) + header.Set("Access-Control-Allow-Methods", cfg.method) + header.Set("Access-Control-Allow-Credentials", "true") + header.Set("Access-Control-Allow-Headers", cfg.headers) + } + return handler(ctx, a) + } + } +}