From ddbfd18e06443ac2f7ce993e6827095ae667df87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BC=9F=E4=B9=90?= Date: Tue, 20 Jun 2023 13:58:30 +0800 Subject: [PATCH] alert --- Makefile | 3 +- api/user.http.go | 132 -------------- api/user.pb.go | 28 +-- api/user.ts | 7 + api/user_grpc.pb.go | 37 ++++ api/user_http.pb.go | 25 ++- example/main.go | 75 +------- example/service/service.go | 13 ++ middleware/cc/cc.go | 356 +++++++++++++++++++++++++++++++++++++ middleware/cors/cors.go | 20 +++ middleware/jwt/token.go | 5 +- proto/user.proto | 7 + 12 files changed, 491 insertions(+), 217 deletions(-) delete mode 100644 api/user.http.go create mode 100644 middleware/cc/cc.go diff --git a/Makefile b/Makefile index 3aeda1f..c14606f 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,7 @@ all: testts: make ts gen: - protoc -I ./third_party -I ./proto --ts_out=./api --go_out=./api --go-grpc_out=./api --kit_out=./api --gin-kit_out=./api --validate_out="lang=go:./api" ./proto/*.proto + #protoc -I ./third_party -I ./proto --ts_out=./api --go_out=./api --go-grpc_out=./api --kit_out=./api --gin-kit_out=./api --validate_out="lang=go:./api" ./proto/*.proto + protoc -I ./third_party -I ./proto --ts_out=./api --go_out=./api --go-grpc_out=./api --kit_out=./api --validate_out="lang=go:./api" ./proto/*.proto auth: protoc --go_out=. ./third_party/auth/auth.proto \ No newline at end of file diff --git a/api/user.http.go b/api/user.http.go deleted file mode 100644 index ec4b198..0000000 --- a/api/user.http.go +++ /dev/null @@ -1,132 +0,0 @@ -package user - -import ( - "github.com/gin-gonic/gin" - "context" - "git.diulo.com/mogfee/kit/middleware" - "git.diulo.com/mogfee/kit/errors" - "git.diulo.com/mogfee/kit/response" -) - -func RegisterUserHandler(app *gin.Engine, srv UserServer, m ...middleware.Middleware) { - app.GET("/api/v1/user/list", httpListHandler(srv, m...)) - app.GET("/api/v1/user/all", httpAllHandler(srv, m...)) - app.GET("/api/v1/user/auto", httpAutoHandler(srv, m...)) - app.GET("/api/v1/user/login_list", httpLoginWithListHandler(srv, m...)) - app.GET("/api/v1/user/login", httpLoginHandler(srv, m...)) -} -func httpListHandler(srv UserServer, m ...middleware.Middleware) func(c *gin.Context) { - return func(c *gin.Context) { - var post Request - resp := response.New(c) - if err := resp.BindQuery(&post); err != nil { - resp.Error(err) - return - } - h := func(ctx context.Context, a any) (any, error) { - return srv.List(ctx, a.(*Request)) - } - out, err := middleware.HttpMiddleware(c, h, m...)(c, &post) - if err != nil { - resp.Error(err) - } else { - if v, ok := out.(*Response); ok { - resp.Success(v) - } else { - resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "Response")) - } - } - } -} -func httpAllHandler(srv UserServer, m ...middleware.Middleware) func(c *gin.Context) { - return func(c *gin.Context) { - var post Request - resp := response.New(c) - if err := resp.BindQuery(&post); err != nil { - resp.Error(err) - return - } - h := func(ctx context.Context, a any) (any, error) { - return srv.All(ctx, a.(*Request)) - } - out, err := middleware.HttpMiddleware(c, h, m...)(c, &post) - if err != nil { - resp.Error(err) - } else { - if v, ok := out.(*Response); ok { - resp.Success(v) - } else { - resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "Response")) - } - } - } -} -func httpAutoHandler(srv UserServer, m ...middleware.Middleware) func(c *gin.Context) { - return func(c *gin.Context) { - var post Request - resp := response.New(c) - if err := resp.BindQuery(&post); err != nil { - resp.Error(err) - return - } - h := func(ctx context.Context, a any) (any, error) { - return srv.Auto(ctx, a.(*Request)) - } - out, err := middleware.HttpMiddleware(c, h, m...)(c, &post) - if err != nil { - resp.Error(err) - } else { - if v, ok := out.(*Response); ok { - resp.Success(v) - } else { - resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "Response")) - } - } - } -} -func httpLoginWithListHandler(srv UserServer, m ...middleware.Middleware) func(c *gin.Context) { - return func(c *gin.Context) { - var post Request - resp := response.New(c) - if err := resp.BindQuery(&post); err != nil { - resp.Error(err) - return - } - h := func(ctx context.Context, a any) (any, error) { - return srv.LoginWithList(ctx, a.(*Request)) - } - out, err := middleware.HttpMiddleware(c, h, m...)(c, &post) - if err != nil { - resp.Error(err) - } else { - if v, ok := out.(*Response); ok { - resp.Success(v) - } else { - resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "Response")) - } - } - } -} -func httpLoginHandler(srv UserServer, m ...middleware.Middleware) func(c *gin.Context) { - return func(c *gin.Context) { - var post Request - resp := response.New(c) - if err := resp.BindQuery(&post); err != nil { - resp.Error(err) - return - } - h := func(ctx context.Context, a any) (any, error) { - return srv.Login(ctx, a.(*Request)) - } - out, err := middleware.HttpMiddleware(c, h, m...)(c, &post) - if err != nil { - resp.Error(err) - } else { - if v, ok := out.(*Response); ok { - resp.Success(v) - } else { - resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "Response")) - } - } - } -} diff --git a/api/user.pb.go b/api/user.pb.go index 0c566dc..90215b0 100644 --- a/api/user.pb.go +++ b/api/user.pb.go @@ -118,7 +118,7 @@ var file_user_proto_rawDesc = []byte{ 0x61, 0x75, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x20, 0x0a, 0x08, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0xcb, 0x03, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, + 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0xa6, 0x04, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x5e, 0x0a, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x12, 0x16, 0x2e, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x69, 0x75, 0x6c, 0x6f, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x69, 0x75, 0x6c, 0x6f, 0x2e, 0x61, 0x70, 0x69, @@ -147,8 +147,14 @@ var file_user_proto_rawDesc = []byte{ 0x1a, 0x17, 0x2e, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x69, 0x75, 0x6c, 0x6f, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x14, 0x12, 0x12, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x2f, - 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x3b, 0x75, 0x73, 0x65, 0x72, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x59, 0x0a, 0x06, 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x31, 0x12, + 0x16, 0x2e, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x69, 0x75, 0x6c, 0x6f, 0x2e, 0x61, 0x70, 0x69, 0x2e, + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x63, 0x6f, 0x6d, 0x2e, 0x64, 0x69, + 0x75, 0x6c, 0x6f, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x1e, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x18, 0x22, 0x13, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, + 0x31, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x2f, 0x6c, 0x6f, 0x67, 0x69, 0x6e, 0x31, 0x3a, 0x01, 0x2a, + 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x3b, 0x75, 0x73, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, } var ( @@ -174,13 +180,15 @@ var file_user_proto_depIdxs = []int32{ 0, // 2: com.diulo.api.user.auto:input_type -> com.diulo.api.request 0, // 3: com.diulo.api.user.loginWithList:input_type -> com.diulo.api.request 0, // 4: com.diulo.api.user.login:input_type -> com.diulo.api.request - 1, // 5: com.diulo.api.user.list:output_type -> com.diulo.api.response - 1, // 6: com.diulo.api.user.all:output_type -> com.diulo.api.response - 1, // 7: com.diulo.api.user.auto:output_type -> com.diulo.api.response - 1, // 8: com.diulo.api.user.loginWithList:output_type -> com.diulo.api.response - 1, // 9: com.diulo.api.user.login:output_type -> com.diulo.api.response - 5, // [5:10] is the sub-list for method output_type - 0, // [0:5] is the sub-list for method input_type + 0, // 5: com.diulo.api.user.login1:input_type -> com.diulo.api.request + 1, // 6: com.diulo.api.user.list:output_type -> com.diulo.api.response + 1, // 7: com.diulo.api.user.all:output_type -> com.diulo.api.response + 1, // 8: com.diulo.api.user.auto:output_type -> com.diulo.api.response + 1, // 9: com.diulo.api.user.loginWithList:output_type -> com.diulo.api.response + 1, // 10: com.diulo.api.user.login:output_type -> com.diulo.api.response + 1, // 11: com.diulo.api.user.login1:output_type -> com.diulo.api.response + 6, // [6:12] is the sub-list for method output_type + 0, // [0:6] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/api/user.ts b/api/user.ts index f63e1ee..dff6deb 100644 --- a/api/user.ts +++ b/api/user.ts @@ -46,4 +46,11 @@ export class userService{ method:'GET' }) } + static async login1(data :request, param?: Config):Promise{ + return http('/api/v1/user/login1', { + ...param, + data: data, + method:'POST' + }) + } } diff --git a/api/user_grpc.pb.go b/api/user_grpc.pb.go index 5b6c20a..ac2d83c 100644 --- a/api/user_grpc.pb.go +++ b/api/user_grpc.pb.go @@ -24,6 +24,7 @@ const ( User_Auto_FullMethodName = "/com.diulo.api.user/auto" User_LoginWithList_FullMethodName = "/com.diulo.api.user/loginWithList" User_Login_FullMethodName = "/com.diulo.api.user/login" + User_Login1_FullMethodName = "/com.diulo.api.user/login1" ) // UserClient is the client API for User service. @@ -37,6 +38,7 @@ type UserClient interface { LoginWithList(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) // 没有 "user:list" 权限 Login(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) + Login1(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) } type userClient struct { @@ -92,6 +94,15 @@ func (c *userClient) Login(ctx context.Context, in *Request, opts ...grpc.CallOp return out, nil } +func (c *userClient) Login1(ctx context.Context, in *Request, opts ...grpc.CallOption) (*Response, error) { + out := new(Response) + err := c.cc.Invoke(ctx, User_Login1_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // UserServer is the server API for User service. // All implementations must embed UnimplementedUserServer // for forward compatibility @@ -103,6 +114,7 @@ type UserServer interface { LoginWithList(context.Context, *Request) (*Response, error) // 没有 "user:list" 权限 Login(context.Context, *Request) (*Response, error) + Login1(context.Context, *Request) (*Response, error) mustEmbedUnimplementedUserServer() } @@ -125,6 +137,9 @@ func (UnimplementedUserServer) LoginWithList(context.Context, *Request) (*Respon func (UnimplementedUserServer) Login(context.Context, *Request) (*Response, error) { return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") } +func (UnimplementedUserServer) Login1(context.Context, *Request) (*Response, error) { + return nil, status.Errorf(codes.Unimplemented, "method Login1 not implemented") +} func (UnimplementedUserServer) mustEmbedUnimplementedUserServer() {} // UnsafeUserServer may be embedded to opt out of forward compatibility for this service. @@ -228,6 +243,24 @@ func _User_Login_Handler(srv interface{}, ctx context.Context, dec func(interfac return interceptor(ctx, in, info, handler) } +func _User_Login1_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Request) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UserServer).Login1(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: User_Login1_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UserServer).Login1(ctx, req.(*Request)) + } + return interceptor(ctx, in, info, handler) +} + // User_ServiceDesc is the grpc.ServiceDesc for User service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -255,6 +288,10 @@ var User_ServiceDesc = grpc.ServiceDesc{ MethodName: "login", Handler: _User_Login_Handler, }, + { + MethodName: "login1", + Handler: _User_Login1_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "user.proto", diff --git a/api/user_http.pb.go b/api/user_http.pb.go index 434debc..4aa8306 100644 --- a/api/user_http.pb.go +++ b/api/user_http.pb.go @@ -1,9 +1,9 @@ package user import ( - "git.diulo.com/mogfee/kit/transport/http" - "git.diulo.com/mogfee/kit/middleware/jwt" "context" + "git.diulo.com/mogfee/kit/middleware/jwt" + "git.diulo.com/mogfee/kit/transport/http" ) type UserHTTPServer interface { @@ -12,6 +12,7 @@ type UserHTTPServer interface { Auto(context.Context, *Request) (*Response, error) LoginWithList(context.Context, *Request) (*Response, error) Login(context.Context, *Request) (*Response, error) + Login1(context.Context, *Request) (*Response, error) } func RegisterUserHTTPServer(s *http.Server, srv UserServer) { @@ -21,6 +22,7 @@ func RegisterUserHTTPServer(s *http.Server, srv UserServer) { r.GET("/api/v1/user/auto", _User_Auto0_HTTP_Handler(srv)) r.GET("/api/v1/user/login_list", _User_LoginWithList0_HTTP_Handler(srv)) r.GET("/api/v1/user/login", _User_Login0_HTTP_Handler(srv)) + r.POST("/api/v1/user/login1", _User_Login10_HTTP_Handler(srv)) } func _User_List0_HTTP_Handler(srv UserHTTPServer) func(ctx http.Context) error { return func(ctx http.Context) error { @@ -121,3 +123,22 @@ func _User_Login0_HTTP_Handler(srv UserHTTPServer) func(ctx http.Context) error return ctx.Result(200, reply) } } +func _User_Login10_HTTP_Handler(srv UserHTTPServer) func(ctx http.Context) error { + return func(ctx http.Context) error { + var in Request + var newCtx context.Context = ctx + if err := ctx.Bind(&in); err != nil { + return err + } + http.SetOperation(ctx, "/com.diulo.api.user/login1") + h := ctx.Middleware(func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.Login1(ctx, req.(*Request)) + }) + out, err := h(newCtx, &in) + if err != nil { + return err + } + reply := out.(*Response) + return ctx.Result(200, reply) + } +} diff --git a/example/main.go b/example/main.go index 77040e5..0f0bb38 100644 --- a/example/main.go +++ b/example/main.go @@ -1,93 +1,26 @@ package main import ( - "context" "flag" "fmt" "git.diulo.com/mogfee/kit" user "git.diulo.com/mogfee/kit/api" - "git.diulo.com/mogfee/kit/errors" "git.diulo.com/mogfee/kit/example/service" "git.diulo.com/mogfee/kit/middleware/jwt" "git.diulo.com/mogfee/kit/transport/http" ) -var host string - -func init() { - flag.StringVar(&host, "h", "localhost:9922", "") -} - func main() { flag.Parse() - runApp(host) + runApp("localhost:8080") } func runApp(host string) { - tokenKey := "1234567890123456" - hs := http.NewServer( http.Address(host), - //http.Middleware( - //logging.Server(), - //validate.Server(), - //jwt.JWT(), - //), - //http.Filter(func(handler http2.Handler) http2.Handler { - // return &cros{ - // h: handler, - // } - //}), - http.Middleware( - jwt.JWT(jwt.WithJwtKey(tokenKey), jwt.WithFromKey("query:token")), - ), - ) - route := hs.Route("/") - route.GET("/api/v1/answer/listCategory", func(ctx http.Context) error { - in := UserAddRequest{Name: "tom"} - http.SetOperation(ctx, "/api/abc") - h := ctx.Middleware(func(ctx context.Context, a any) (any, error) { - //tr, _ := transport.FromServerContext(ctx) - //return AddUser(ctx, a.(*UserAddRequest)) - return &UserAddResponse{Id: "xxx"}, nil - return nil, errors.New(400, "BAD", "a") - }) - out, err := h(ctx, &in) - if err != nil { - return err - } - reply, _ := out.(*UserAddResponse) - reply.Id = host - return ctx.Result(200, reply) - }) - - user.RegisterUserHTTPServer(hs, service.NewUserService(tokenKey)) - //client, err := clientv3.New(clientv3.Config{ - // Endpoints: []string{"127.0.0.1:2379"}, - //}) - //if err != nil { - // panic(err) - //} - app := kit.New( - kit.Name("kit-server"), - kit.Version("v1.0"), - kit.ID(host), - kit.Server(hs), - //kit.Registrar(etcd.New(client)), + http.Middleware(jwt.JWT(jwt.WithJwtKey("Jt6Zv!KTopXZ6S4C"))), ) + user.RegisterUserHTTPServer(hs, service.NewUserService("1234567890123456")) + app := kit.New(kit.Server(hs)) fmt.Println(app.Run()) fmt.Println(app.Stop()) } - -type UserAddRequest struct { - Name string -} -type UserAddResponse struct { - Id string -} - -func AddUser(ctx context.Context, request *UserAddRequest) (*UserAddResponse, error) { - //fmt.Println(jwt.FromUserContext(ctx)) - //fmt.Println(jwt.FromAuthKeyContext(ctx)) - return &UserAddResponse{Id: request.Name}, nil - //errors.New(500, "xx", "") -} diff --git a/example/service/service.go b/example/service/service.go index e4e88c0..de21034 100644 --- a/example/service/service.go +++ b/example/service/service.go @@ -70,3 +70,16 @@ func (s *userServer) Login(ctx context.Context, request *user.Request) (*user.Re } return &user.Response{Token: token}, nil } + +func (s *userServer) Login1(ctx context.Context, request *user.Request) (*user.Response, error) { + token, _, err := jwt.GetToken(s.tokenKey, &jwt.UserInfo{ + UserId: 1, + UserName: "test", + UserType: "user", + UniqueId: "", + }) + if err != nil { + return nil, err + } + return &user.Response{Token: token}, nil +} diff --git a/middleware/cc/cc.go b/middleware/cc/cc.go new file mode 100644 index 0000000..7d4236a --- /dev/null +++ b/middleware/cc/cc.go @@ -0,0 +1,356 @@ +package cc + +import ( + "fmt" + "net/http" + "strconv" + "strings" +) + +// CORSOption represents a functional option for configuring the CORS middleware. +type CORSOption func(*cors) error + +type cors struct { + h http.Handler + allowedHeaders []string + allowedMethods []string + allowedOrigins []string + allowedOriginValidator OriginValidator + exposedHeaders []string + maxAge int + ignoreOptions bool + allowCredentials bool + optionStatusCode int +} + +// OriginValidator takes an origin string and returns whether or not that origin is allowed. +type OriginValidator func(string) bool + +var ( + defaultCorsOptionStatusCode = 200 + defaultCorsMethods = []string{"GET", "HEAD", "POST"} + defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} + // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) +) + +const ( + corsOptionMethod string = "OPTIONS" + corsAllowOriginHeader string = "Access-Control-Allow-Origin" + corsExposeHeadersHeader string = "Access-Control-Expose-Headers" + corsMaxAgeHeader string = "Access-Control-Max-Age" + corsAllowMethodsHeader string = "Access-Control-Allow-Methods" + corsAllowHeadersHeader string = "Access-Control-Allow-Headers" + corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" + corsRequestMethodHeader string = "Access-Control-Request-Method" + corsRequestHeadersHeader string = "Access-Control-Request-Headers" + corsOriginHeader string = "Origin" + corsVaryHeader string = "Vary" + corsOriginMatchAll string = "*" +) + +func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get(corsOriginHeader) + if !ch.isOriginAllowed(origin) { + if r.Method != corsOptionMethod || ch.ignoreOptions { + ch.h.ServeHTTP(w, r) + } + + return + } + + if r.Method == corsOptionMethod { + if ch.ignoreOptions { + ch.h.ServeHTTP(w, r) + return + } + + if _, ok := r.Header[corsRequestMethodHeader]; !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + method := r.Header.Get(corsRequestMethodHeader) + if !ch.isMatch(method, ch.allowedMethods) { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") + allowedHeaders := []string{} + for _, v := range requestHeaders { + canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { + continue + } + + if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { + w.WriteHeader(http.StatusForbidden) + return + } + + allowedHeaders = append(allowedHeaders, canonicalHeader) + } + + if len(allowedHeaders) > 0 { + w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) + } + + if ch.maxAge > 0 { + w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) + } + + if !ch.isMatch(method, defaultCorsMethods) { + w.Header().Set(corsAllowMethodsHeader, method) + } + } else { + if len(ch.exposedHeaders) > 0 { + w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) + } + } + + if ch.allowCredentials { + w.Header().Set(corsAllowCredentialsHeader, "true") + } + + if len(ch.allowedOrigins) > 1 { + w.Header().Set(corsVaryHeader, corsOriginHeader) + } + + returnOrigin := origin + if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { + returnOrigin = "*" + } else { + for _, o := range ch.allowedOrigins { + // A configuration of * is different than explicitly setting an allowed + // origin. Returning arbitrary origin headers in an access control allow + // origin header is unsafe and is not required by any use case. + if o == corsOriginMatchAll { + returnOrigin = "*" + break + } + } + } + w.Header().Set(corsAllowOriginHeader, returnOrigin) + + if r.Method == corsOptionMethod { + w.WriteHeader(ch.optionStatusCode) + return + } + fmt.Println("222") + ch.h.ServeHTTP(w, r) +} + +// CORS provides Cross-Origin Resource Sharing middleware. +// Example: +// +// import ( +// "net/http" +// +// "github.com/gorilla/handlers" +// "github.com/gorilla/mux" +// ) +// +// func main() { +// r := mux.NewRouter() +// r.HandleFunc("/users", UserEndpoint) +// r.HandleFunc("/projects", ProjectEndpoint) +// +// // Apply the CORS middleware to our top-level router, with the defaults. +// http.ListenAndServe(":8000", handlers.CORS()(r)) +// } +func CORS(opts ...CORSOption) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + ch := parseCORSOptions(opts...) + ch.h = h + return ch + } +} + +func parseCORSOptions(opts ...CORSOption) *cors { + ch := &cors{ + allowedMethods: defaultCorsMethods, + allowedHeaders: defaultCorsHeaders, + allowedOrigins: []string{}, + optionStatusCode: defaultCorsOptionStatusCode, + } + + for _, option := range opts { + option(ch) + } + + return ch +} + +// +// Functional options for configuring CORS. +// + +// AllowedHeaders adds the provided headers to the list of allowed headers in a +// CORS request. +// This is an append operation so the headers Accept, Accept-Language, +// and Content-Language are always allowed. +// Content-Type must be explicitly declared if accepting Content-Types other than +// application/x-www-form-urlencoded, multipart/form-data, or text/plain. +func AllowedHeaders(headers []string) CORSOption { + return func(ch *cors) error { + for _, v := range headers { + normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if normalizedHeader == "" { + continue + } + + if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { + ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) + } + } + + return nil + } +} + +// AllowedMethods can be used to explicitly allow methods in the +// Access-Control-Allow-Methods header. +// This is a replacement operation so you must also +// pass GET, HEAD, and POST if you wish to support those methods. +func AllowedMethods(methods []string) CORSOption { + return func(ch *cors) error { + ch.allowedMethods = []string{} + for _, v := range methods { + normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) + if normalizedMethod == "" { + continue + } + + if !ch.isMatch(normalizedMethod, ch.allowedMethods) { + ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) + } + } + + return nil + } +} + +// AllowedOrigins sets the allowed origins for CORS requests, as used in the +// 'Allow-Access-Control-Origin' HTTP header. +// Note: Passing in a []string{"*"} will allow any domain. +func AllowedOrigins(origins []string) CORSOption { + return func(ch *cors) error { + for _, v := range origins { + if v == corsOriginMatchAll { + ch.allowedOrigins = []string{corsOriginMatchAll} + return nil + } + } + + ch.allowedOrigins = origins + return nil + } +} + +// AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the +// 'Allow-Access-Control-Origin' HTTP header. +func AllowedOriginValidator(fn OriginValidator) CORSOption { + return func(ch *cors) error { + ch.allowedOriginValidator = fn + return nil + } +} + +// OptionStatusCode sets a custom status code on the OPTIONS requests. +// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory +// and can be used if you need a custom status code (i.e 204). +// +// More informations on the spec: +// https://fetch.spec.whatwg.org/#cors-preflight-fetch +func OptionStatusCode(code int) CORSOption { + return func(ch *cors) error { + ch.optionStatusCode = code + return nil + } +} + +// ExposedHeaders can be used to specify headers that are available +// and will not be stripped out by the user-agent. +func ExposedHeaders(headers []string) CORSOption { + return func(ch *cors) error { + ch.exposedHeaders = []string{} + for _, v := range headers { + normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if normalizedHeader == "" { + continue + } + + if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { + ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) + } + } + + return nil + } +} + +// MaxAge determines the maximum age (in seconds) between preflight requests. A +// maximum of 10 minutes is allowed. An age above this value will default to 10 +// minutes. +func MaxAge(age int) CORSOption { + return func(ch *cors) error { + // Maximum of 10 minutes. + if age > 600 { + age = 600 + } + + ch.maxAge = age + return nil + } +} + +// IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead +// passing them through to the next handler. This is useful when your application +// or framework has a pre-existing mechanism for responding to OPTIONS requests. +func IgnoreOptions() CORSOption { + return func(ch *cors) error { + ch.ignoreOptions = true + return nil + } +} + +// AllowCredentials can be used to specify that the user agent may pass +// authentication details along with the request. +func AllowCredentials() CORSOption { + return func(ch *cors) error { + ch.allowCredentials = true + return nil + } +} + +func (ch *cors) isOriginAllowed(origin string) bool { + if origin == "" { + return false + } + + if ch.allowedOriginValidator != nil { + return ch.allowedOriginValidator(origin) + } + + if len(ch.allowedOrigins) == 0 { + return true + } + + for _, allowedOrigin := range ch.allowedOrigins { + if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { + return true + } + } + + return false +} + +func (ch *cors) isMatch(needle string, haystack []string) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + + return false +} diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index e083c45..d10b16c 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -2,8 +2,10 @@ package cors import ( "context" + "fmt" "git.diulo.com/mogfee/kit/middleware" "git.diulo.com/mogfee/kit/transport" + "net/http" ) type OptionFunc func(o *option) @@ -42,6 +44,7 @@ func Cors(ops ...OptionFunc) middleware.Middleware { } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, a any) (any, error) { + fmt.Println("=======") if tr, ok := transport.FromServerContext(ctx); ok { header := tr.ReplyHeader() header.Set("Access-Control-Allow-Origin", cfg.domain) @@ -53,3 +56,20 @@ func Cors(ops ...OptionFunc) middleware.Middleware { } } } + +func HttpServer() func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + ch := &cors{h: h} + ch.h = h + return ch + } +} + +type cors struct { + h http.Handler +} + +func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + ch.h.ServeHTTP(w, r) +} diff --git a/middleware/jwt/token.go b/middleware/jwt/token.go index 7313212..8b8792a 100644 --- a/middleware/jwt/token.go +++ b/middleware/jwt/token.go @@ -65,7 +65,10 @@ func Parse(key string, tokenStr string) (*UserInfo, error) { return []byte(key), nil }) if err != nil { - return nil, err + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, errors.Unauthorized("TOKEN_EXPIRED", "") + } + return nil, errors.Unauthorized("TOKEN_ERROR", err.Error()) } if token.Valid { diff --git a/proto/user.proto b/proto/user.proto index 0d9161e..6a14d56 100644 --- a/proto/user.proto +++ b/proto/user.proto @@ -43,6 +43,13 @@ service user{ get: "/api/v1/user/login", }; } + + rpc login1(request)returns(response){ + option (google.api.http) = { + post: "/api/v1/user/login1", + body:"*", + }; + } } message request{ }