package http import ( "context" "crypto/tls" "git.diulo.com/mogfee/kit/internal/matcher" "git.diulo.com/mogfee/kit/middleware" "git.diulo.com/mogfee/kit/transport" "github.com/gorilla/mux" "net" "net/http" "net/url" "time" ) var ( _ transport.Server = (*Server)(nil) _ transport.Endpointer = (*Server)(nil) _ http.Handler = (*Server)(nil) ) type ServerOption func(server *Server) // Network with server network. func Network(network string) ServerOption { return func(s *Server) { s.network = network } } // Address with server address. func Address(addr string) ServerOption { return func(s *Server) { s.address = addr } } // Timeout with server timeout. func Timeout(timeout time.Duration) ServerOption { return func(s *Server) { s.timeout = timeout } } // Logger with server logger. // Deprecated: use global logger instead. func Logger(logger log.Logger) ServerOption { return func(s *Server) {} } // Middleware with service middleware option. func Middleware(m ...middleware.Middleware) ServerOption { return func(o *Server) { o.middleware.Use(m...) } } // Filter with HTTP middleware option. func Filter(filters ...FilterFunc) ServerOption { return func(o *Server) { o.filters = filters } } // RequestVarsDecoder with request decoder. func RequestVarsDecoder(dec DecodeRequestFunc) ServerOption { return func(o *Server) { o.decVars = dec } } // RequestQueryDecoder with request decoder. func RequestQueryDecoder(dec DecodeRequestFunc) ServerOption { return func(o *Server) { o.decQuery = dec } } // RequestDecoder with request decoder. func RequestDecoder(dec DecodeRequestFunc) ServerOption { return func(o *Server) { o.decBody = dec } } // ResponseEncoder with response encoder. func ResponseEncoder(en EncodeResponseFunc) ServerOption { return func(o *Server) { o.enc = en } } // ErrorEncoder with error encoder. func ErrorEncoder(en EncodeErrorFunc) ServerOption { return func(o *Server) { o.ene = en } } // TLSConfig with TLS config. func TLSConfig(c *tls.Config) ServerOption { return func(o *Server) { o.tlsConf = c } } // StrictSlash is with mux's StrictSlash // If true, when the path pattern is "/path/", accessing "/path" will // redirect to the former and vice versa. func StrictSlash(strictSlash bool) ServerOption { return func(o *Server) { o.strictSlash = strictSlash } } // Listener with server lis func Listener(lis net.Listener) ServerOption { return func(s *Server) { s.lis = lis } } // PathPrefix with mux's PathPrefix, router will replaced by a subrouter that start with prefix. func PathPrefix(prefix string) ServerOption { return func(s *Server) { s.router = s.router.PathPrefix(prefix).Subrouter() } } type Server struct { *http.Server lis net.Listener tlsConf *tls.Config endpoint *url.URL err error network string address string timeout time.Duration filters []FilterFunc middleware matcher.Matcher decVars DecodeRequestFunc decQuery DecodeRequestFunc decBody DecodeRequestFunc enc EncodeResponseFunc ene EncodeErrorFunc strictSlash bool router *mux.Router } func NewServer(opts ...ServerOption) *Server { srv := &Server{ network: "tcp", address: ":0", timeout: 1 * time.Second, middleware: matcher.New(), decVars: DefaultRequestVars, decQuery: DefaultRequestQuery, decBody: DefaultRequestDecoder, strictSlash: true, router: mux.NewRouter(), } for _, o := range opts { o(srv) } srv.router.StrictSlash(srv.strictSlash) srv.router.NotFoundHandler = http.DefaultServeMux srv.router.MethodNotAllowedHandler = http.DefaultServeMux srv.router.Use(srv.filter()) srv.Server = &http.Server{ Handler: FilterChain(srv.filters...)(srv.router), TLSConfig: srv.tlsConf, } return srv } func (s *Server) Use(selector string, m ...middleware.Middleware) { s.middleware.Add(selector, m...) } func (s *Server) WalkRoute(fn WalkRoutFunc) error { return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { methods, err := route.GetMethods() if err != nil { return err } path, err := route.GetPathTemplate() if err != nil { return err } for _, method := range methods { if err = fn(RouteInfo{ Method: method, Path: path, }); err != nil { return err } } return nil }) } func (s *Server) Kind() transport.Kind { return transport.htt } func (s *Server) Route(prefix string, filters ...FilterFunc) *Router { return newRouter(prefix, s, filters...) } func (s *Server) Handle(path string, h http.Handler) { s.router.Handle(path, h) } func (s *Server) HandlePrefix(prefix string, h http.Handler) { s.router.PathPrefix(prefix).Handler(h) } func (s *Server) HandleFunc(path string, h http.HandlerFunc) { s.router.HandleFunc(path, h) } func (s *Server) HandleHeader(key, val string, h http.Handler) { s.router.Headers(key, val).Handler(h) } func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { s.Handler.ServeHTTP(res, req) } func (s *Server) filter() mux.MiddlewareFunc { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { var ( ctx context.Context cancel context.CancelFunc ) if s.timeout > 0 { ctx, cancel = context.WithTimeout(request.Context(), s.timeout) } else { ctx, cancel = context.WithCancel(request.Context()) } defer cancel() pathTemplate := request.URL.Path if route := mux.CurrentRoute(request); route != nil { pathTemplate, _ = route.GetPathTemplate() } tr := &Transport{} }) } } func (s *Server) Endpoint() string { //TODO implement me panic("implement me") } func (s *Server) Operation() string { //TODO implement me panic("implement me") } func (s *Server) RequestHeader() transport.Header { //TODO implement me panic("implement me") } func (s *Server) ReplyHeader() transport.Header { //TODO implement me panic("implement me") }