You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

313 lines
7.2 KiB

2 years ago
package http
import (
"context"
"crypto/tls"
2 years ago
"errors"
"git.diulo.com/mogfee/kit/internal/endpoint"
"git.diulo.com/mogfee/kit/internal/host"
2 years ago
"git.diulo.com/mogfee/kit/internal/matcher"
2 years ago
"git.diulo.com/mogfee/kit/log"
2 years ago
"git.diulo.com/mogfee/kit/middleware"
"git.diulo.com/mogfee/kit/transport"
"github.com/gorilla/mux"
"net"
"net/http"
"net/url"
"time"
)
var (
_ transport.Server = (*Server)(nil)
_ transport.Endpointer = (*Server)(nil)
_ http.Handler = (*Server)(nil)
)
// ServerOption configures a Server; options are applied in order by NewServer.
type ServerOption func(*Server)
// Network sets the network passed to net.Listen (default "tcp").
func Network(network string) ServerOption {
	return func(srv *Server) { srv.network = network }
}
// Address sets the address passed to net.Listen (default ":0"); it is also
// used when deriving the advertised endpoint.
func Address(addr string) ServerOption {
	return func(srv *Server) { srv.address = addr }
}
// Timeout sets the deadline applied to each matched request's context
// (default 1s). A value <= 0 disables the deadline.
func Timeout(timeout time.Duration) ServerOption {
	return func(srv *Server) { srv.timeout = timeout }
}
// Logger is a no-op kept for backward compatibility; the supplied logger is
// ignored.
//
// Deprecated: use global logger instead.
func Logger(logger log.Logger) ServerOption {
	return func(*Server) {}
}
// Middleware registers m on the server's middleware matcher via its Use
// method.
func Middleware(m ...middleware.Middleware) ServerOption {
	return func(srv *Server) {
		srv.middleware.Use(m...)
	}
}
// Filter sets the net/http-level filters that wrap the whole router.
// A later Filter option replaces, rather than extends, an earlier one.
func Filter(filters ...FilterFunc) ServerOption {
	return func(srv *Server) { srv.filters = filters }
}
// RequestVarsDecoder replaces the server's vars decoder
// (default: DefaultRequestVars).
func RequestVarsDecoder(dec DecodeRequestFunc) ServerOption {
	return func(srv *Server) { srv.decVars = dec }
}
// RequestQueryDecoder replaces the server's query decoder
// (default: DefaultRequestQuery).
func RequestQueryDecoder(dec DecodeRequestFunc) ServerOption {
	return func(srv *Server) { srv.decQuery = dec }
}
// RequestDecoder replaces the server's body decoder
// (default: DefaultRequestDecoder).
func RequestDecoder(dec DecodeRequestFunc) ServerOption {
	return func(srv *Server) { srv.decBody = dec }
}
// ResponseEncoder replaces the server's response encoder
// (default: DefaultResponseEncoder).
func ResponseEncoder(en EncodeResponseFunc) ServerOption {
	return func(srv *Server) { srv.enc = en }
}
// ErrorEncoder replaces the server's error encoder
// (default: DefaultErrorEncoder).
func ErrorEncoder(en EncodeErrorFunc) ServerOption {
	return func(srv *Server) { srv.ene = en }
}
// TLSConfig sets the TLS configuration. When non-nil, Start serves with
// ServeTLS and the endpoint scheme is built with TLS enabled.
func TLSConfig(c *tls.Config) ServerOption {
	return func(srv *Server) { srv.tlsConf = c }
}
// StrictSlash configures mux's StrictSlash behavior (default true).
// When enabled and the path pattern is "/path/", a request for "/path"
// is redirected to "/path/", and vice versa.
func StrictSlash(strictSlash bool) ServerOption {
	return func(srv *Server) { srv.strictSlash = strictSlash }
}
// Listener supplies a pre-opened listener. When set, the server does not call
// net.Listen itself; Address is still used to derive the advertised endpoint.
func Listener(lis net.Listener) ServerOption {
	return func(srv *Server) { srv.lis = lis }
}
// PathPrefix replaces the server's router with a mux subrouter that matches
// only paths under prefix. Because options run in order, it applies to
// whatever router is current when the option is evaluated.
func PathPrefix(prefix string) ServerOption {
	return func(srv *Server) {
		sub := srv.router.PathPrefix(prefix).Subrouter()
		srv.router = sub
	}
}
// Server is an HTTP transport server routed by gorilla/mux. It embeds
// *http.Server (created by NewServer), so the standard server methods are
// available directly on it.
type Server struct {
*http.Server
// lis is the listener; opened lazily by listenAndEndpoint unless supplied
// via the Listener option.
lis net.Listener
// tlsConf, when non-nil, makes Start serve TLS.
tlsConf *tls.Config
// endpoint is the advertised URL derived from lis by listenAndEndpoint.
endpoint *url.URL
// err records the failure from listenAndEndpoint.
err error
// network and address are passed to net.Listen.
network string
address string
// timeout is the per-request deadline applied in filter; <= 0 disables it.
timeout time.Duration
// filters wrap the whole router at the net/http level.
filters []FilterFunc
// middleware holds selector-based middleware registered via Middleware/Use.
middleware matcher.Matcher
// Request decoders and response/error encoders; defaults set in NewServer.
decVars DecodeRequestFunc
decQuery DecodeRequestFunc
decBody DecodeRequestFunc
enc EncodeResponseFunc
ene EncodeErrorFunc
// strictSlash is applied to router in NewServer.
strictSlash bool
// router is the mux router (possibly a subrouter, see PathPrefix).
router *mux.Router
}
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
middleware: matcher.New(),
decVars: DefaultRequestVars,
decQuery: DefaultRequestQuery,
decBody: DefaultRequestDecoder,
2 years ago
enc: DefaultResponseEncoder,
ene: DefaultErrorEncoder,
2 years ago
strictSlash: true,
router: mux.NewRouter(),
}
for _, o := range opts {
o(srv)
}
srv.router.StrictSlash(srv.strictSlash)
srv.router.NotFoundHandler = http.DefaultServeMux
srv.router.MethodNotAllowedHandler = http.DefaultServeMux
srv.router.Use(srv.filter())
srv.Server = &http.Server{
Handler: FilterChain(srv.filters...)(srv.router),
TLSConfig: srv.tlsConf,
}
return srv
}
// Use registers middleware m on the server's matcher under selector, via the
// matcher's Add method.
func (s *Server) Use(selector string, m ...middleware.Middleware) {
	s.middleware.Add(selector, m...)
}
// WalkRoute calls fn once for every (method, path template) pair registered on
// the server's router. mux's Walk also descends into subrouters.
//
// Routes without a method restriction are skipped instead of aborting the
// walk, because they have no method to report. Examples are routes added with
// Handle, HandleFunc, HandlePrefix or HandleHeader.
//
// The first error from fn, or from resolving a route's path template, stops
// the walk and is returned.
func (s *Server) WalkRoute(fn WalkRoutFunc) error {
	return s.router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
		methods, err := route.GetMethods()
		if err != nil {
			// mux returns an error for routes that carry no method matcher.
			// Propagating it made WalkRoute fail as soon as any such route
			// existed, so skip the route instead.
			return nil
		}
		path, err := route.GetPathTemplate()
		if err != nil {
			return err
		}
		for _, method := range methods {
			if err := fn(RouteInfo{Method: method, Path: path}); err != nil {
				return err
			}
		}
		return nil
	})
}
func (s *Server) Kind() transport.Kind {
2 years ago
return transport.KindHTTP
2 years ago
}
// Route returns a *Router for prefix with the given filters, as built by
// newRouter (defined elsewhere in this package).
func (s *Server) Route(prefix string, filters ...FilterFunc) *Router {
	r := newRouter(prefix, s, filters...)
	return r
}
// Handle registers h for requests whose path matches the mux path template.
// No method restriction is applied.
func (s *Server) Handle(path string, h http.Handler) {
	s.router.NewRoute().Path(path).Handler(h)
}
// HandlePrefix registers h for requests whose path matches the mux path
// prefix template.
func (s *Server) HandlePrefix(prefix string, h http.Handler) {
	s.router.NewRoute().PathPrefix(prefix).Handler(h)
}
// HandleFunc registers the handler function h for requests whose path matches
// the mux path template. No method restriction is applied.
func (s *Server) HandleFunc(path string, h http.HandlerFunc) {
	s.router.NewRoute().Path(path).HandlerFunc(h)
}
// HandleHeader registers h for requests whose header key matches val
// (mux Route.Headers matching). No path or method restriction is applied.
func (s *Server) HandleHeader(key, val string, h http.Handler) {
	s.router.NewRoute().Headers(key, val).Handler(h)
}
func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
s.Handler.ServeHTTP(res, req)
}
2 years ago
// Endpoint returns the server's advertised URL. It opens the listener first
// if that has not happened yet.
func (s *Server) Endpoint() (*url.URL, error) {
	err := s.listenAndEndpoint()
	if err != nil {
		return nil, err
	}
	return s.endpoint, nil
}
// Start opens the listener if needed and serves until the server is closed.
// It uses HTTPS via ServeTLS when a TLS config was supplied, and plain HTTP
// otherwise. ctx becomes the base context of every incoming connection.
// A normal shutdown (http.ErrServerClosed) is reported as nil.
func (s *Server) Start(ctx context.Context) error {
	if err := s.listenAndEndpoint(); err != nil {
		return err
	}
	s.BaseContext = func(net.Listener) context.Context { return ctx }
	log.Infof("[HTTP] server listening on: %s", s.endpoint)

	serve := func() error { return s.Serve(s.lis) }
	if s.tlsConf != nil {
		serve = func() error { return s.ServeTLS(s.lis, "", "") }
	}
	if err := serve(); !errors.Is(err, http.ErrServerClosed) {
		return err
	}
	return nil
}
func (s *Server) Stop(ctx context.Context) error {
log.Info("[HTTP] server stopping")
return s.Shutdown(ctx)
}
2 years ago
func (s *Server) filter() mux.MiddlewareFunc {
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
var (
ctx context.Context
cancel context.CancelFunc
)
if s.timeout > 0 {
ctx, cancel = context.WithTimeout(request.Context(), s.timeout)
} else {
ctx, cancel = context.WithCancel(request.Context())
}
defer cancel()
pathTemplate := request.URL.Path
if route := mux.CurrentRoute(request); route != nil {
pathTemplate, _ = route.GetPathTemplate()
}
2 years ago
tr := &Transport{
operation: pathTemplate,
pathTemplate: pathTemplate,
reqHeader: headerCarrier(request.Header),
replyHeader: headerCarrier(writer.Header()),
request: request,
}
if s.endpoint != nil {
tr.endpoint = s.endpoint.String()
}
tr.request = request.WithContext(transport.NewServerContext(ctx, tr))
handler.ServeHTTP(writer, request)
2 years ago
})
}
}
2 years ago
func (s *Server) listenAndEndpoint() error {
if s.lis == nil {
lis, err := net.Listen(s.network, s.address)
if err != nil {
s.err = err
return err
}
s.lis = lis
}
if s.endpoint == nil {
addr, err := host.Extract(s.address, s.lis)
if err != nil {
s.err = err
return err
}
s.endpoint = endpoint.NewEndpoint(endpoint.Scheme("http", s.tlsConf != nil), addr)
}
return s.err
2 years ago
}