From 75c9c1c909413cfed261c380011c50c6c0645482 Mon Sep 17 00:00:00 2001 From: mogfee Date: Sat, 4 Mar 2023 18:22:32 +0800 Subject: [PATCH] x --- app.go | 88 ++++++++++- cmd/kit/main.go | 8 +- encoding/encoding.go | 25 ++++ encoding/form/form.go | 80 ++++++++++ {xerrors => errors}/errors.go | 2 +- {xerrors => errors}/types.go | 2 +- example/service/service.go | 2 +- go.mod | 3 +- go.sum | 2 + internal/httputil/http.go | 26 ++++ internal/matcher/matcher.go | 56 +++++++ options.go | 4 +- registry/registry.go | 65 ++++++++ test/main.go | 16 ++ transport/http/binding/bind.go | 25 ++++ transport/http/codec.go | 95 ++++++++++++ transport/http/context.go | 201 +++++++++++++++++++++++++ transport/http/filter.go | 14 ++ transport/http/router.go | 59 ++++++++ transport/http/server.go | 263 +++++++++++++++++++++++++++++++++ transport/http/transport.go | 65 ++++++++ transport/transport.go | 58 ++++++++ 22 files changed, 1149 insertions(+), 10 deletions(-) create mode 100644 encoding/encoding.go create mode 100644 encoding/form/form.go rename {xerrors => errors}/errors.go (98%) rename {xerrors => errors}/types.go (99%) create mode 100644 internal/httputil/http.go create mode 100644 internal/matcher/matcher.go create mode 100644 registry/registry.go create mode 100644 test/main.go create mode 100644 transport/http/binding/bind.go create mode 100644 transport/http/codec.go create mode 100644 transport/http/context.go create mode 100644 transport/http/filter.go create mode 100644 transport/http/router.go create mode 100644 transport/http/server.go create mode 100644 transport/http/transport.go create mode 100644 transport/transport.go diff --git a/app.go b/app.go index 0cb5492..2820551 100644 --- a/app.go +++ b/app.go @@ -2,10 +2,14 @@ package protoc_gen_kit import ( "context" + "errors" "git.diulo.com/mogfee/protoc-gen-kit/log" + "git.diulo.com/mogfee/protoc-gen-kit/registry" + "git.diulo.com/mogfee/protoc-gen-kit/transport" "github.com/google/uuid" "golang.org/x/sync/errgroup" "os" + "os/signal" "sync" "syscall" "time" @@ -90,7 +94,7 @@ func (a *App) Run() error { srv := srv eg.Go(func() error { <-ctx.Done() - stopCtx, cancel := context.WithTimeout(NewContext(a.opts.ctx, a)) + stopCtx, cancel := context.WithTimeout(NewContext(a.opts.ctx, a), a.opts.stopTimeout) defer cancel() return srv.Stop(stopCtx) }) @@ -101,5 +105,87 @@ func (a *App) Run() error { }) } wg.Wait() + if a.opts.registrar != nil { + rctx, rcancel := context.WithTimeout(ctx, a.opts.registrarTimeout) + defer rcancel() + if err = a.opts.registrar.Register(rctx, instance); err != nil { + return err + } + } + for _, fn := range a.opts.afterStart { + if err = fn(sctx); err != nil { + return err + } + } + c := make(chan os.Signal, 1) + signal.Notify(c, a.opts.sigs...) + eg.Go(func() error { + select { + case <-ctx.Done(): + return nil + case <-c: + return a.Stop() + } + }) + if err = eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { + for _, fn := range a.opts.afterStop { + err = fn(sctx) + } + } + return err +} +func (a *App) Stop() (err error) { + sctx := NewContext(a.ctx, a) + for _, fn := range a.opts.beforeStop { + err = fn(sctx) + } + a.mu.Lock() + instance := a.instance + a.mu.Unlock() + if a.opts.registrar != nil && instance != nil { + ctx, cancel := context.WithTimeout(NewContext(a.ctx, a), a.opts.stopTimeout) + defer cancel() + if err = a.opts.registrar.Deregister(ctx, instance); err != nil { + return err + } + } + if a.cancel != nil { + a.cancel() + } + return err +} +func (a *App) buildInstance() (*registry.ServiceInstance, error) { + endpoints := make([]string, 0, len(a.opts.endpoints)) + for _, e := range a.opts.endpoints { + endpoints = append(endpoints, e.String()) + } + if len(endpoints) == 0 { + for _, srv := range a.opts.servers { + if r, ok := srv.(transport.Endpointer); ok { + e, err := r.Endpoint() + if err != nil { + return nil, err + } + endpoints = append(endpoints, e.String()) + } + } + } + return ®istry.ServiceInstance{ + ID: a.opts.id, + Name: a.opts.name, + Version: a.opts.version, + Metadata: a.opts.metadata, + Endpoints: endpoints, + }, nil +} +type appKey struct { +} + +func NewContext(ctx context.Context, s AppInfo) context.Context { + return context.WithValue(ctx, appKey{}, s) +} +func FromContext(ctx context.Context) (s AppInfo, ok bool) { + s, ok = ctx.Value(appKey{}).(AppInfo) + return } diff --git a/cmd/kit/main.go b/cmd/kit/main.go index 7ee4218..9ccb233 100644 --- a/cmd/kit/main.go +++ b/cmd/kit/main.go @@ -26,7 +26,7 @@ func (u *Kit) Generate(plugin *protogen.Plugin) error { u.addImports("context") u.addImports("git.diulo.com/mogfee/protoc-gen-kit/middleware") u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/response") - u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/xerrors") + u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/errors") u.addImports("github.com/gin-gonic/gin") for _, f := range plugin.Files { if len(f.Services) == 0 { @@ -94,7 +94,7 @@ func (u *Kit) genGet(serverName string, t *protogen.GeneratedFile, m *protogen.M if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok { resp.Success(v) } else { - resp.Error(xerrors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) + resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) } } }`) @@ -119,7 +119,7 @@ func (u *Kit) genPost(serverName string, t *protogen.GeneratedFile, m *protogen. if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok { resp.Success(v) } else { - resp.Error(xerrors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) + resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) } } }`) @@ -144,7 +144,7 @@ func (u *Kit) genDelete(serverName string, t *protogen.GeneratedFile, m *protoge if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok { resp.Success(v) } else { - resp.Error(xerrors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) + resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) } } }`) diff --git a/encoding/encoding.go b/encoding/encoding.go new file mode 100644 index 0000000..e26013d --- /dev/null +++ b/encoding/encoding.go @@ -0,0 +1,25 @@ +package encoding + +import "strings" + +type Codec interface { + Marshal(any) ([]byte, error) + Unmarshal([]byte, any) error + Name() string +} + +var registeredCodecs = make(map[string]Codec) + +func RegisterCodec(codec Codec) { + if codec == nil { + panic("cannot register a nil Codec") + } + if codec.Name() == "" { + panic("cannot register Codec with empty string result for Name()") + } + contentSubType := strings.ToLower(codec.Name()) + registeredCodecs[contentSubType] = codec +} +func GetCodec(contentSubType string) Codec { + return registeredCodecs[contentSubType] +} diff --git a/encoding/form/form.go b/encoding/form/form.go new file mode 100644 index 0000000..1058037 --- /dev/null +++ b/encoding/form/form.go @@ -0,0 +1,80 @@ +package form + +import ( + "git.diulo.com/mogfee/kit/encoding" + "github.com/go-playground/form" + "google.golang.org/protobuf/proto" + "net/url" + "reflect" +) + +const ( + Name = "x-www-form-urlencoded" + nullStr = "null" +) + +var ( + encoder = form.NewEncoder() + decoder = form.NewDecoder() +) + +func init() { + decoder.SetTagName("json") + encoder.SetTagName("json") + encoding.RegisterCodec(codec{ + encoder: encoder, + decoder: decoder, + }) +} + +type codec struct { + encoder *form.Encoder + decoder *form.Decoder +} + +func (c codec) Marshal(v any) ([]byte, error) { + var vs url.Values + var err error + if m, ok := v.(proto.Message); ok { + vs, err = EncodeValues(m) + if err != nil { + return nil, err + } + } else { + vs, err = c.encoder.Encode(v) + if err != nil { + return nil, err + } + } + for k, v := range vs { + if len(v) == 0 { + delete(vs, k) + } + } + return []byte(vs.Encode()), err +} + +func (c codec) Unmarshal(data []byte, v any) error { + vs, err := url.ParseQuery(string(data)) + if err != nil { + return err + } + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + if m, ok := v.(proto.Message); ok { + return DecodeValues(m, vs) + } + if m, ok := rv.Interface().(proto.Message); ok { + return DecodeValues(m, vs) + } + return c.decoder.Decode(v, vs) +} + +func (c codec) Name() string { + return Name +} diff --git a/xerrors/errors.go b/errors/errors.go similarity index 98% rename from xerrors/errors.go rename to errors/errors.go index 2690887..9d5fe1a 100644 --- a/xerrors/errors.go +++ b/errors/errors.go @@ -1,4 +1,4 @@ -package xerrors +package errors import ( "errors" diff --git a/xerrors/types.go b/errors/types.go similarity index 99% rename from xerrors/types.go rename to errors/types.go index 503cea1..e3d033b 100644 --- a/xerrors/types.go +++ b/errors/types.go @@ -1,5 +1,5 @@ // nolint:gomnd -package xerrors +package errors // BadRequest new BadRequest error that is mapped to a 400 response. func BadRequest(reason, message string) *Error { diff --git a/example/service/service.go b/example/service/service.go index 8416d1a..931e00a 100644 --- a/example/service/service.go +++ b/example/service/service.go @@ -19,7 +19,7 @@ func (*UserService) Login(ctx context.Context, req *user.LoginRequest) (*user.Lo } func (*UserService) List(ctx context.Context, req *user.LoginRequest) (*user.LoginResponse, error) { //fmt.Println(ctx.Value("userId")) - //return nil, errors.Wrap(xerrors.InternalServer("InternalServer", "B"), "") + //return nil, errors.Wrap(errors.InternalServer("InternalServer", "B"), "") //b, _ := json.Marshal(req) return &user.LoginResponse{Token: "123123"}, nil diff --git a/go.mod b/go.mod index 5be21e8..a9a10bf 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module git.diulo.com/mogfee/protoc-gen-kit +module git.diulo.com/mogfee/kit go 1.20 @@ -23,6 +23,7 @@ require ( github.com/goccy/go-json v0.10.0 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/uuid v1.3.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.2 // indirect diff --git a/go.sum b/go.sum index bf47404..458d55a 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/httputil/http.go b/internal/httputil/http.go new file mode 100644 index 0000000..e84453c --- /dev/null +++ b/internal/httputil/http.go @@ -0,0 +1,26 @@ +package httputil + +import "strings" + +const ( + baseContentType = "application" +) + +func ContentType(subtype string) string { + return strings.Join([]string{baseContentType, subtype}, "/") +} + +func ContentSubtype(contentType string) string { + left := strings.Index(contentType, "/") + if left == -1 { + return "" + } + right := strings.Index(contentType, ";") + if right == -1 { + right = len(contentType) + } + if right < left { + return "" + } + return contentType[left+1 : right] +} diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go new file mode 100644 index 0000000..48e1dfb --- /dev/null +++ b/internal/matcher/matcher.go @@ -0,0 +1,56 @@ +package matcher + +import ( + "git.diulo.com/mogfee/kit/middleware" + "sort" + "strings" +) + +type Matcher interface { + Use(ms ...middleware.Middleware) + Add(selector string, ms ...middleware.Middleware) + Match(operation string) []middleware.Middleware +} + +func New() Matcher { + return &matcher{ + matchs: make(map[string][]middleware.Middleware), + } +} + +type matcher struct { + prefix []string + defaults []middleware.Middleware + matchs map[string][]middleware.Middleware +} + +func (m *matcher) Use(ms ...middleware.Middleware) { + m.defaults = ms +} + +func (m *matcher) Add(selector string, ms ...middleware.Middleware) { + if strings.HasSuffix(selector, "*") { + selector = strings.TrimSuffix(selector, "*") + m.prefix = append(m.prefix, selector) + sort.Slice(m.prefix, func(i, j int) bool { + return m.prefix[i] > m.prefix[j] + }) + } + m.matchs[selector] = ms +} + +func (m *matcher) Match(operation string) []middleware.Middleware { + ms := make([]middleware.Middleware, 0, len(m.defaults)) + if len(m.defaults) > 0 { + ms = append(ms, m.defaults...) + } + if next, ok := m.matchs[operation]; ok { + ms = append(ms, next...) + } + for _, prefix := range m.prefix { + if strings.HasPrefix(operation, prefix) { + return append(ms, m.matchs[prefix]...) + } + } + return ms +} diff --git a/options.go b/options.go index 9817831..c124b4b 100644 --- a/options.go +++ b/options.go @@ -2,7 +2,9 @@ package protoc_gen_kit import ( "context" - "git.diulo.com/mogfee/protoc-gen-kit/log" + "git.diulo.com/mogfee/kit/log" + "git.diulo.com/mogfee/kit/registry" + "git.diulo.com/mogfee/kit/transport" "net/url" "os" "time" diff --git a/registry/registry.go b/registry/registry.go new file mode 100644 index 0000000..f959b0d --- /dev/null +++ b/registry/registry.go @@ -0,0 +1,65 @@ +package registry + +import ( + "context" + "fmt" + "sort" +) + +type Registrar interface { + Register(ctx context.Context, service *ServiceInstance) error + Deregister(ctx context.Context, service *ServiceInstance) error +} +type Discovery interface { + GetService(ctx context.Context, serverName string) ([]ServiceInstance, error) + Watch(ctx context.Context, serviceName string) (Watch, error) +} +type Watch interface { + Next() ([]*ServiceInstance, error) + Stop() error +} + +type ServiceInstance struct { + ID string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Metadata map[string]string `json:"metadata"` + Endpoints []string `json:"endpoints"` +} + +func (i *ServiceInstance) String() string { + return fmt.Sprintf("%s-%s", i.Name, i.ID) +} +func (i *ServiceInstance) Equal(o any) bool { + if i == nil && o == nil { + return true + } + if i == nil || o == nil { + return false + } + t, ok := o.(*ServiceInstance) + if !ok { + return false + } + if len(i.Endpoints) != len(t.Endpoints) { + return false + } + + sort.Strings(i.Endpoints) + sort.Strings(t.Endpoints) + + for j := 0; j < len(i.Endpoints); j++ { + if i.Endpoints[j] != t.Endpoints[j] { + return false + } + } + if len(i.Metadata) != len(t.Metadata) { + return false + } + for k, v := range i.Metadata { + if v != t.Metadata[k] { + return false + } + } + return i.ID == t.ID && i.Name != t.Name && i.Version != t.Version +} diff --git a/test/main.go b/test/main.go new file mode 100644 index 0000000..0f1e71f --- /dev/null +++ b/test/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + protoc_gen_kit "git.diulo.com/mogfee/protoc-gen-kit" +) + +func main() { + app := protoc_gen_kit.New( + protoc_gen_kit.Name("user-server"), + protoc_gen_kit.Server()) + fmt.Println("run start") + app.Run() + fmt.Println("run end") + app.Stop() +} diff --git a/transport/http/binding/bind.go b/transport/http/binding/bind.go new file mode 100644 index 0000000..796f442 --- /dev/null +++ b/transport/http/binding/bind.go @@ -0,0 +1,25 @@ +package binding + +import ( + "git.diulo.com/mogfee/kit/encoding" + "git.diulo.com/mogfee/kit/encoding/form" + "git.diulo.com/mogfee/kit/errors" + "net/http" + "net/url" +) + +func BindQuery(vars url.Values, target any) error { + if err := encoding.GetCodec(form.Name).Unmarshal([]byte(vars.Encode()), target); err != nil { + return errors.BadRequest("CODEC", err.Error()) + } + return nil +} +func BindForm(req *http.Request, target any) error { + if err := req.ParseForm(); err != nil { + return err + } + if err := encoding.GetCodec(form.Name).Unmarshal([]byte(req.Form.Encode()), target); err != nil { + return errors.BadRequest("CODEC", err.Error()) + } + return nil +} diff --git a/transport/http/codec.go b/transport/http/codec.go new file mode 100644 index 0000000..f89a29b --- /dev/null +++ b/transport/http/codec.go @@ -0,0 +1,95 @@ +package http + +import ( + "bytes" + "fmt" + "git.diulo.com/mogfee/kit/encoding" + "git.diulo.com/mogfee/kit/errors" + "git.diulo.com/mogfee/kit/internal/httputil" + "git.diulo.com/mogfee/kit/transport/http/binding" + "github.com/gorilla/mux" + "io" + "net/http" + "net/url" +) + +const SupportPackageIsVersion1 = true + +type Redirector interface { + Redirect() (string, int) +} +type Request = http.Request +type ResponseWriter = http.ResponseWriter +type Flusher = http.Flusher +type DecodeRequestFunc func(*http.Request, any) error +type EncodeResponseFunc func(http.ResponseWriter, *http.Request, any) error +type EncodeErrorFunc func(w http.ResponseWriter, r *http.Request, err error) + +func DefaultRequestVars(r *http.Request, v any) error { + raws := mux.Vars(r) + vars := make(url.Values, len(raws)) + for k, v := range raws { + vars[k] = []string{v} + } + return binding.BindQuery(vars, v) +} +func DefaultRequestQuery(r *http.Request, v any) error { + return binding.BindQuery(r.URL.Query(), v) +} +func DefaultRequestDecoder(r *http.Request, v any) error { + codec, ok := CodeForRequest(r, "Content-type") + if !ok { + return errors.BadRequest("CODEC", fmt.Sprintf("unregister Content-Type: %s", codec)) + } + data, err := io.ReadAll(r.Body) + if err != nil { + return errors.BadRequest("CODEC", err.Error()) + } + r.Body = io.NopCloser(bytes.NewBuffer(data)) + if len(data) == 0 { + return nil + } + if err = codec.Unmarshal(data, v); err != nil { + return errors.BadRequest("CODEC", err.Error()) + } + return nil +} +func DefaultResponseEncoder(w http.ResponseWriter, r *http.Request, v any) error { + if v == nil { + return nil + } + if rd, ok := v.(Redirector); ok { + url, code := rd.Redirect() + http.Redirect(w, r, url, code) + return nil + } + codec, _ := CodeForRequest(r, "Accept") + data, err := codec.Marshal(v) + if err != nil { + return err + } + w.Header().Set("Content-type", httputil.ContentType(codec.Name())) + _, err = w.Write(data) + return err +} +func DefaultErrorEncoder(w http.ResponseWriter, r *http.Request, err error) { + se := errors.FromError(err) + codec, _ := CodeForRequest(r, "Accept") + body, err := codec.Marshal(se) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", httputil.ContentType(codec.Name())) + w.WriteHeader(int(se.Status)) + _, _ = w.Write(body) +} +func CodeForRequest(r *http.Request, name string) (encoding.Codec, bool) { + for _, accept := range r.Header[name] { + codec := encoding.GetCodec(httputil.ContentSubtype(accept)) + if codec != nil { + return codec, true + } + } + return encoding.GetCodec("json"), false +} diff --git a/transport/http/context.go b/transport/http/context.go new file mode 100644 index 0000000..329ae35 --- /dev/null +++ b/transport/http/context.go @@ -0,0 +1,201 @@ +package http + +import ( + "context" + "encoding/json" + "encoding/xml" + "git.diulo.com/mogfee/kit/middleware" + "git.diulo.com/mogfee/kit/transport" + "git.diulo.com/mogfee/kit/transport/http/binding" + "github.com/gorilla/mux" + "io" + "net/http" + "net/url" + "time" +) + +type Context interface { + context.Context + Vars() url.Values + Query() url.Values + Form() url.Values + Header() http.Header + Request() *http.Request + Response() http.ResponseWriter + Middleware(middleware.Handler) middleware.Handler + Bind(interface{}) error + BindVars(interface{}) error + BindQuery(interface{}) error + BindForm(interface{}) error + Returns(interface{}, error) error + Result(int, interface{}) error + JSON(int, interface{}) error + XML(int, interface{}) error + String(int, string) error + Blob(int, string, []byte) error + Stream(int, string, io.Reader) error + Reset(http.ResponseWriter, *http.Request) +} +type responseWriter struct { + code int + w http.ResponseWriter +} + +func (r *responseWriter) Header() http.Header { + return r.w.Header() +} + +func (r *responseWriter) Write(bytes []byte) (int, error) { + r.w.WriteHeader(r.code) + return r.w.Write(bytes) +} + +func (r *responseWriter) WriteHeader(statusCode int) { + r.code = statusCode +} + +func (r *responseWriter) reset(res http.ResponseWriter) { + r.w = res + r.code = http.StatusOK +} + +type wrapper struct { + router *Router + req *http.Request + res http.ResponseWriter + w responseWriter +} + +func (w *wrapper) Deadline() (deadline time.Time, ok bool) { + if w.req == nil { + return time.Time{}, false + } + return w.req.Context().Deadline() +} + +func (w *wrapper) Done() <-chan struct{} { + if w.req == nil { + return nil + } + return w.req.Context().Done() +} + +func (w *wrapper) Err() error { + if w.req == nil { + return context.Canceled + } + return w.req.Context().Err() +} + +func (w *wrapper) Value(key any) any { + if w.req == nil { + return nil + } + return w.req.Context().Value(key) +} + +func (w *wrapper) Vars() url.Values { + raws := mux.Vars(w.req) + vars := make(url.Values, len(raws)) + for k, v := range raws { + vars[k] = []string{v} + } + return vars +} + +func (w *wrapper) Query() url.Values { + return w.req.URL.Query() +} + +func (w *wrapper) Form() url.Values { + if err := w.req.ParseForm(); err != nil { + return url.Values{} + } + return w.req.Form +} + +func (w *wrapper) Header() http.Header { + return w.req.Header +} + +func (w *wrapper) Request() *http.Request { + return w.req +} + +func (w *wrapper) Response() http.ResponseWriter { + return w.res +} + +func (w *wrapper) Middleware(handler middleware.Handler) middleware.Handler { + if tr, ok := transport.FromServerContext(w.req.Context()); ok { + return middleware.Chain(w.router.srv.middleware.Match(tr.Operation())...)(handler) + } + return middleware.Chain(w.router.srv.middleware.Match(w.req.URL.Path)...)(handler) +} + +func (w *wrapper) Bind(i any) error { + return w.router.srv.decBody(w.req, i) +} + +func (w *wrapper) BindVars(i interface{}) error { + return w.router.srv.decVars(w.req, i) +} + +func (w *wrapper) BindQuery(i interface{}) error { + return w.router.srv.decQuery(w.req, i) +} + +func (w *wrapper) BindForm(i interface{}) error { + return binding.BindForm(w.req, i) +} + +func (w *wrapper) Returns(i interface{}, err error) error { + if err != nil { + return err + } + return w.router.srv.enc(&w.w, w.req, i) +} + +func (w *wrapper) Result(i int, i2 interface{}) error { + w.w.WriteHeader(i) + return w.router.srv.enc(&w.w, w.req, i2) +} + +func (w *wrapper) JSON(i int, i2 interface{}) error { + w.res.Header().Set("Content-Type", "application/json") + w.res.WriteHeader(i) + return json.NewEncoder(w.res).Encode(i2) +} + +func (w *wrapper) XML(i int, i2 interface{}) error { + w.res.Header().Set("Content-Type", "application/xml") + w.res.WriteHeader(i) + return xml.NewEncoder(w.res).Encode(i2) +} + +func (w *wrapper) String(i int, s string) error { + w.res.Header().Set("Content-Type", "text/plain") + w.res.WriteHeader(i) + _, err := w.res.Write([]byte(s)) + return err +} + +func (w *wrapper) Blob(i int, s string, bytes []byte) error { + w.res.Header().Set("Content-Type", s) + w.res.WriteHeader(i) + _, err := w.res.Write(bytes) + return err +} + +func (w *wrapper) Stream(i int, s string, reader io.Reader) error { + w.res.Header().Set("Content-Type", s) + w.res.WriteHeader(i) + _, err := io.Copy(w.res, reader) + return err +} + +func (w *wrapper) Reset(writer http.ResponseWriter, request *http.Request) { + w.w.reset(writer) + w.res = writer + w.req = request +} diff --git a/transport/http/filter.go b/transport/http/filter.go new file mode 100644 index 0000000..f4c1d29 --- /dev/null +++ b/transport/http/filter.go @@ -0,0 +1,14 @@ +package http + +import "net/http" + +type FilterFunc func(http.Handler) http.Handler + +func FilterChain(filters ...FilterFunc) FilterFunc { + return func(next http.Handler) http.Handler { + for i := len(filters) - 1; i >= 0; i-- { + next = filters[i](next) + } + return next + } +} diff --git a/transport/http/router.go b/transport/http/router.go new file mode 100644 index 0000000..bc0a803 --- /dev/null +++ b/transport/http/router.go @@ -0,0 +1,59 @@ +package http + +import ( + "net/http" + "path" + "sync" +) + +type WalkRoutFunc func(RouteInfo) error + +type RouteInfo struct { + Path string + Method string +} +type HandlerFunc func(Context) error +type Router struct { + prefix string + pool sync.Pool + srv *Server + filters []FilterFunc +} + +func newRouter(prefix string, srv *Server, filters ...FilterFunc) *Router { + r := &Router{ + prefix: prefix, + srv: srv, + filters: filters, + } + r.pool.New = func() any { + return &wrapper{router: r} + } + return r +} +func (r *Router) Group(prefix string, filters ...FilterFunc) *Router { + var newFilters []FilterFunc + newFilters = append(newFilters, r.filters...) + newFilters = append(newFilters, filters...) + return newRouter(path.Join(r.prefix, prefix), r.srv, newFilters...) +} +func (r *Router) Handle(method string, relativePath string, h HandlerFunc, filters ...FilterFunc) { + next := http.Handler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + ctx := r.pool.Get().(Context) + ctx.Reset(res, req) + if err := h(ctx); err != nil { + r.srv.ene(res, req, err) + } + ctx.Reset(nil, nil) + r.pool.Put(ctx) + })) + next = FilterChain(filters...)(next) + next = FilterChain(r.filters...)(next) + r.srv.router.Handle(path.Join(r.prefix, relativePath), next).Methods(method) +} +func (r *Router) GET(path string, h HandlerFunc, m ...FilterFunc) { + r.Handle(http.MethodGet, path, h, m...) +} +func (r *Router) POST(path string, h HandlerFunc, m ...FilterFunc) { + r.Handle(http.MethodPost, path, h, m...) +} diff --git a/transport/http/server.go b/transport/http/server.go new file mode 100644 index 0000000..2b4e6ef --- /dev/null +++ b/transport/http/server.go @@ -0,0 +1,263 @@ +package http + +import ( + "context" + "crypto/tls" + "git.diulo.com/mogfee/kit/internal/matcher" + "git.diulo.com/mogfee/kit/middleware" + "git.diulo.com/mogfee/kit/transport" + "github.com/gorilla/mux" + "net" + "net/http" + "net/url" + "time" +) + +var ( + _ transport.Server = (*Server)(nil) + _ transport.Endpointer = (*Server)(nil) + _ http.Handler = (*Server)(nil) +) + +type ServerOption func(server *Server) + +// Network with server network. +func Network(network string) ServerOption { + return func(s *Server) { + s.network = network + } +} + +// Address with server address. +func Address(addr string) ServerOption { + return func(s *Server) { + s.address = addr + } +} + +// Timeout with server timeout. +func Timeout(timeout time.Duration) ServerOption { + return func(s *Server) { + s.timeout = timeout + } +} + +// Logger with server logger. +// Deprecated: use global logger instead. +func Logger(logger log.Logger) ServerOption { + return func(s *Server) {} +} + +// Middleware with service middleware option. +func Middleware(m ...middleware.Middleware) ServerOption { + return func(o *Server) { + o.middleware.Use(m...) + } +} + +// Filter with HTTP middleware option. +func Filter(filters ...FilterFunc) ServerOption { + return func(o *Server) { + o.filters = filters + } +} + +// RequestVarsDecoder with request decoder. +func RequestVarsDecoder(dec DecodeRequestFunc) ServerOption { + return func(o *Server) { + o.decVars = dec + } +} + +// RequestQueryDecoder with request decoder. +func RequestQueryDecoder(dec DecodeRequestFunc) ServerOption { + return func(o *Server) { + o.decQuery = dec + } +} + +// RequestDecoder with request decoder. +func RequestDecoder(dec DecodeRequestFunc) ServerOption { + return func(o *Server) { + o.decBody = dec + } +} + +// ResponseEncoder with response encoder. +func ResponseEncoder(en EncodeResponseFunc) ServerOption { + return func(o *Server) { + o.enc = en + } +} + +// ErrorEncoder with error encoder. +func ErrorEncoder(en EncodeErrorFunc) ServerOption { + return func(o *Server) { + o.ene = en + } +} + +// TLSConfig with TLS config. +func TLSConfig(c *tls.Config) ServerOption { + return func(o *Server) { + o.tlsConf = c + } +} + +// StrictSlash is with mux's StrictSlash +// If true, when the path pattern is "/path/", accessing "/path" will +// redirect to the former and vice versa. +func StrictSlash(strictSlash bool) ServerOption { + return func(o *Server) { + o.strictSlash = strictSlash + } +} + +// Listener with server lis +func Listener(lis net.Listener) ServerOption { + return func(s *Server) { + s.lis = lis + } +} + +// PathPrefix with mux's PathPrefix, router will replaced by a subrouter that start with prefix. +func PathPrefix(prefix string) ServerOption { + return func(s *Server) { + s.router = s.router.PathPrefix(prefix).Subrouter() + } +} + +type Server struct { + *http.Server + lis net.Listener + tlsConf *tls.Config + endpoint *url.URL + err error + network string + address string + timeout time.Duration + filters []FilterFunc + middleware matcher.Matcher + decVars DecodeRequestFunc + decQuery DecodeRequestFunc + decBody DecodeRequestFunc + enc EncodeResponseFunc + ene EncodeErrorFunc + strictSlash bool + router *mux.Router +} + +func NewServer(opts ...ServerOption) *Server { + srv := &Server{ + network: "tcp", + address: ":0", + timeout: 1 * time.Second, + middleware: matcher.New(), + decVars: DefaultRequestVars, + decQuery: DefaultRequestQuery, + decBody: DefaultRequestDecoder, + strictSlash: true, + router: mux.NewRouter(), + } + for _, o := range opts { + o(srv) + } + + srv.router.StrictSlash(srv.strictSlash) + srv.router.NotFoundHandler = http.DefaultServeMux + srv.router.MethodNotAllowedHandler = http.DefaultServeMux + srv.router.Use(srv.filter()) + srv.Server = &http.Server{ + Handler: FilterChain(srv.filters...)(srv.router), + TLSConfig: srv.tlsConf, + } + return srv +} +func (s *Server) Use(selector string, m ...middleware.Middleware) { + s.middleware.Add(selector, m...) +} +func (s *Server) WalkRoute(fn WalkRoutFunc) error { + return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + methods, err := route.GetMethods() + if err != nil { + return err + } + path, err := route.GetPathTemplate() + if err != nil { + return err + } + for _, method := range methods { + if err = fn(RouteInfo{ + Method: method, + Path: path, + }); err != nil { + return err + } + } + return nil + }) +} +func (s *Server) Kind() transport.Kind { + return transport.htt +} +func (s *Server) Route(prefix string, filters ...FilterFunc) *Router { + return newRouter(prefix, s, filters...) +} +func (s *Server) Handle(path string, h http.Handler) { + s.router.Handle(path, h) +} +func (s *Server) HandlePrefix(prefix string, h http.Handler) { + s.router.PathPrefix(prefix).Handler(h) +} +func (s *Server) HandleFunc(path string, h http.HandlerFunc) { + s.router.HandleFunc(path, h) +} +func (s *Server) HandleHeader(key, val string, h http.Handler) { + s.router.Headers(key, val).Handler(h) +} +func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { + s.Handler.ServeHTTP(res, req) +} +func (s *Server) filter() mux.MiddlewareFunc { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + var ( + ctx context.Context + cancel context.CancelFunc + ) + + if s.timeout > 0 { + ctx, cancel = context.WithTimeout(request.Context(), s.timeout) + } else { + ctx, cancel = context.WithCancel(request.Context()) + } + defer cancel() + + pathTemplate := request.URL.Path + if route := mux.CurrentRoute(request); route != nil { + pathTemplate, _ = route.GetPathTemplate() + } + tr := &Transport{} + }) + } +} + +func (s *Server) Endpoint() string { + //TODO implement me + panic("implement me") +} + +func (s *Server) Operation() string { + //TODO implement me + panic("implement me") +} + +func (s *Server) RequestHeader() transport.Header { + //TODO implement me + panic("implement me") +} + +func (s *Server) ReplyHeader() transport.Header { + //TODO implement me + panic("implement me") +} diff --git a/transport/http/transport.go b/transport/http/transport.go new file mode 100644 index 0000000..1a0654c --- /dev/null +++ b/transport/http/transport.go @@ -0,0 +1,65 @@ +package http + +import ( + "context" + "git.diulo.com/mogfee/kit/transport" + "net/http" +) + +type Transporter interface { + transport.Transporter + Request() *http.Request + PathTemplate() string +} + +type Transport struct { + endpoint string + operation string + reqHeader headerCarrier + replyHeader headerCarrier + request *http.Request + pathTemplate string +} + +func (t *Transport) Kind() transport.Kind { + return transport.KindHTTP +} + +func (t *Transport) Endpoint() string { + return t.endpoint +} + +func (t *Transport) Operation() string { + return t.operation +} + +func (t *Transport) RequestHeader() transport.Header { + return t.reqHeader +} + +func (t *Transport) ReplyHeader() transport.Header { + return t.replyHeader +} + +func (t *Transport) Request() *http.Request { + return t.request +} + +func (t *Transport) PathTemplate() string { + return t.pathTemplate +} +func SetOperation(ctx context.Context, op string) { + if tr, ok := transport.FromServerContext(ctx); ok { + if tr, ok := tr.(*Transport); ok { + tr.operation = op + } + } +} +func RequestFromServerContext(ctx context.Context) (*http.Request, bool) { + if tr, ok := transport.FromServerContext(ctx); ok { + if tr, ok := tr.(*Transport); ok { + return tr.request, true + } + } + return nil, false +} diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 0000000..8a84181 --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,58 @@ +package transport + +import ( + "context" + "net/url" +) + +type Server interface { + Start(context.Context) error + Stop(context.Context) error +} + +type Endpointer interface { + Endpoint() (*url.URL, error) +} +type Header interface { + Get(key string) string + Set(key string, value string) + Keys() []string +} +type Transporter interface { + Kind() Kind + Endpoint() string + Operation() string + RequestHeader() Header + ReplyHeader() Header +} +type Kind string + +func (k Kind) String() string { + return string(k) +} + +const ( + KindGRPC Kind = "grpc" + KindHTTP Kind = "http" +) + +type ( + serverTransportKey struct{} + clientTransportKey struct{} +) + +func NewServerContext(ctx context.Context, tr Transporter) context.Context { + return context.WithValue(ctx, serverTransportKey{}, tr) +} +func FromServerContext(ctx context.Context) (tr Transporter, ok bool) { + tr, ok = ctx.Value(serverTransportKey{}).(Transporter) + return +} +func NewClientContext(ctx context.Context, tr Transporter) context.Context { + return context.WithValue(ctx, clientTransportKey{}, tr) +} + +func FromClientContext(ctx context.Context) (tr Transporter, ok bool) { + tr, ok = ctx.Value(clientTransportKey{}).(Transporter) + return +}