mogfee 2 years ago
parent 2cd5b2ded0
commit 75c9c1c909
  1. 88
      app.go
  2. 8
      cmd/kit/main.go
  3. 25
      encoding/encoding.go
  4. 80
      encoding/form/form.go
  5. 2
      errors/errors.go
  6. 2
      errors/types.go
  7. 2
      example/service/service.go
  8. 3
      go.mod
  9. 2
      go.sum
  10. 26
      internal/httputil/http.go
  11. 56
      internal/matcher/matcher.go
  12. 4
      options.go
  13. 65
      registry/registry.go
  14. 16
      test/main.go
  15. 25
      transport/http/binding/bind.go
  16. 95
      transport/http/codec.go
  17. 201
      transport/http/context.go
  18. 14
      transport/http/filter.go
  19. 59
      transport/http/router.go
  20. 263
      transport/http/server.go
  21. 65
      transport/http/transport.go
  22. 58
      transport/transport.go

@ -2,10 +2,14 @@ package protoc_gen_kit
import ( import (
"context" "context"
"errors"
"git.diulo.com/mogfee/protoc-gen-kit/log" "git.diulo.com/mogfee/protoc-gen-kit/log"
"git.diulo.com/mogfee/protoc-gen-kit/registry"
"git.diulo.com/mogfee/protoc-gen-kit/transport"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"os" "os"
"os/signal"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -90,7 +94,7 @@ func (a *App) Run() error {
srv := srv srv := srv
eg.Go(func() error { eg.Go(func() error {
<-ctx.Done() <-ctx.Done()
stopCtx, cancel := context.WithTimeout(NewContext(a.opts.ctx, a)) stopCtx, cancel := context.WithTimeout(NewContext(a.opts.ctx, a), a.opts.stopTimeout)
defer cancel() defer cancel()
return srv.Stop(stopCtx) return srv.Stop(stopCtx)
}) })
@ -101,5 +105,87 @@ func (a *App) Run() error {
}) })
} }
wg.Wait() wg.Wait()
if a.opts.registrar != nil {
rctx, rcancel := context.WithTimeout(ctx, a.opts.registrarTimeout)
defer rcancel()
if err = a.opts.registrar.Register(rctx, instance); err != nil {
return err
}
}
for _, fn := range a.opts.afterStart {
if err = fn(sctx); err != nil {
return err
}
}
c := make(chan os.Signal, 1)
signal.Notify(c, a.opts.sigs...)
eg.Go(func() error {
select {
case <-ctx.Done():
return nil
case <-c:
return a.Stop()
}
})
if err = eg.Wait(); err != nil && !errors.Is(err, context.Canceled) {
for _, fn := range a.opts.afterStop {
err = fn(sctx)
}
}
return err
}
func (a *App) Stop() (err error) {
sctx := NewContext(a.ctx, a)
for _, fn := range a.opts.beforeStop {
err = fn(sctx)
}
a.mu.Lock()
instance := a.instance
a.mu.Unlock()
if a.opts.registrar != nil && instance != nil {
ctx, cancel := context.WithTimeout(NewContext(a.ctx, a), a.opts.stopTimeout)
defer cancel()
if err = a.opts.registrar.Deregister(ctx, instance); err != nil {
return err
}
}
if a.cancel != nil {
a.cancel()
}
return err
}
func (a *App) buildInstance() (*registry.ServiceInstance, error) {
endpoints := make([]string, 0, len(a.opts.endpoints))
for _, e := range a.opts.endpoints {
endpoints = append(endpoints, e.String())
}
if len(endpoints) == 0 {
for _, srv := range a.opts.servers {
if r, ok := srv.(transport.Endpointer); ok {
e, err := r.Endpoint()
if err != nil {
return nil, err
}
endpoints = append(endpoints, e.String())
}
}
}
return &registry.ServiceInstance{
ID: a.opts.id,
Name: a.opts.name,
Version: a.opts.version,
Metadata: a.opts.metadata,
Endpoints: endpoints,
}, nil
}
type appKey struct {
}
func NewContext(ctx context.Context, s AppInfo) context.Context {
return context.WithValue(ctx, appKey{}, s)
}
func FromContext(ctx context.Context) (s AppInfo, ok bool) {
s, ok = ctx.Value(appKey{}).(AppInfo)
return
} }

@ -26,7 +26,7 @@ func (u *Kit) Generate(plugin *protogen.Plugin) error {
u.addImports("context") u.addImports("context")
u.addImports("git.diulo.com/mogfee/protoc-gen-kit/middleware") u.addImports("git.diulo.com/mogfee/protoc-gen-kit/middleware")
u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/response") u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/response")
u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/xerrors") u.addImports("git.diulo.com/mogfee/protoc-gen-kit/pkg/errors")
u.addImports("github.com/gin-gonic/gin") u.addImports("github.com/gin-gonic/gin")
for _, f := range plugin.Files { for _, f := range plugin.Files {
if len(f.Services) == 0 { if len(f.Services) == 0 {
@ -94,7 +94,7 @@ func (u *Kit) genGet(serverName string, t *protogen.GeneratedFile, m *protogen.M
if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok { if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok {
resp.Success(v) resp.Success(v)
} else { } else {
resp.Error(xerrors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `"))
} }
} }
}`) }`)
@ -119,7 +119,7 @@ func (u *Kit) genPost(serverName string, t *protogen.GeneratedFile, m *protogen.
if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok { if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok {
resp.Success(v) resp.Success(v)
} else { } else {
resp.Error(xerrors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `"))
} }
} }
}`) }`)
@ -144,7 +144,7 @@ func (u *Kit) genDelete(serverName string, t *protogen.GeneratedFile, m *protoge
if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok { if v, ok := out.(*`, m.Output.GoIdent.GoName, `); ok {
resp.Success(v) resp.Success(v)
} else { } else {
resp.Error(xerrors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `")) resp.Error(errors.InternalServer("RESULT_TYPE_ERROR", "`, m.Output.GoIdent.GoName, `"))
} }
} }
}`) }`)

@ -0,0 +1,25 @@
package encoding
import "strings"
type Codec interface {
Marshal(any) ([]byte, error)
Unmarshal([]byte, any) error
Name() string
}
var registeredCodecs = make(map[string]Codec)
func RegisterCodec(codec Codec) {
if codec == nil {
panic("cannot register a nil Codec")
}
if codec.Name() == "" {
panic("cannot register Codec with empty string result for Name()")
}
contentSubType := strings.ToLower(codec.Name())
registeredCodecs[contentSubType] = codec
}
func GetCodec(contentSubType string) Codec {
return registeredCodecs[contentSubType]
}

@ -0,0 +1,80 @@
package form
import (
"git.diulo.com/mogfee/kit/encoding"
"github.com/go-playground/form"
"google.golang.org/protobuf/proto"
"net/url"
"reflect"
)
const (
Name = "x-www-form-urlencoded"
nullStr = "null"
)
var (
encoder = form.NewEncoder()
decoder = form.NewDecoder()
)
func init() {
decoder.SetTagName("json")
encoder.SetTagName("json")
encoding.RegisterCodec(codec{
encoder: encoder,
decoder: decoder,
})
}
type codec struct {
encoder *form.Encoder
decoder *form.Decoder
}
func (c codec) Marshal(v any) ([]byte, error) {
var vs url.Values
var err error
if m, ok := v.(proto.Message); ok {
vs, err = EncodeValues(m)
if err != nil {
return nil, err
}
} else {
vs, err = c.encoder.Encode(v)
if err != nil {
return nil, err
}
}
for k, v := range vs {
if len(v) == 0 {
delete(vs, k)
}
}
return []byte(vs.Encode()), err
}
func (c codec) Unmarshal(data []byte, v any) error {
vs, err := url.ParseQuery(string(data))
if err != nil {
return err
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
rv = rv.Elem()
}
if m, ok := v.(proto.Message); ok {
return DecodeValues(m, vs)
}
if m, ok := rv.Interface().(proto.Message); ok {
return DecodeValues(m, vs)
}
return c.decoder.Decode(v, vs)
}
func (c codec) Name() string {
return Name
}

@ -1,4 +1,4 @@
package xerrors package errors
import ( import (
"errors" "errors"

@ -1,5 +1,5 @@
// nolint:gomnd // nolint:gomnd
package xerrors package errors
// BadRequest new BadRequest error that is mapped to a 400 response. // BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *Error { func BadRequest(reason, message string) *Error {

@ -19,7 +19,7 @@ func (*UserService) Login(ctx context.Context, req *user.LoginRequest) (*user.Lo
} }
func (*UserService) List(ctx context.Context, req *user.LoginRequest) (*user.LoginResponse, error) { func (*UserService) List(ctx context.Context, req *user.LoginRequest) (*user.LoginResponse, error) {
//fmt.Println(ctx.Value("userId")) //fmt.Println(ctx.Value("userId"))
//return nil, errors.Wrap(xerrors.InternalServer("InternalServer", "B"), "") //return nil, errors.Wrap(errors.InternalServer("InternalServer", "B"), "")
//b, _ := json.Marshal(req) //b, _ := json.Marshal(req)
return &user.LoginResponse{Token: "123123"}, nil return &user.LoginResponse{Token: "123123"}, nil

@ -1,4 +1,4 @@
module git.diulo.com/mogfee/protoc-gen-kit module git.diulo.com/mogfee/kit
go 1.20 go 1.20
@ -23,6 +23,7 @@ require (
github.com/goccy/go-json v0.10.0 // indirect github.com/goccy/go-json v0.10.0 // indirect
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.2 // indirect github.com/leodido/go-urn v1.2.2 // indirect

@ -35,6 +35,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=

@ -0,0 +1,26 @@
package httputil
import "strings"
const (
baseContentType = "application"
)
func ContentType(subtype string) string {
return strings.Join([]string{baseContentType, subtype}, "/")
}
func ContentSubtype(contentType string) string {
left := strings.Index(contentType, "/")
if left == -1 {
return ""
}
right := strings.Index(contentType, ";")
if right == -1 {
right = len(contentType)
}
if right < left {
return ""
}
return contentType[left+1 : right]
}

@ -0,0 +1,56 @@
package matcher
import (
"git.diulo.com/mogfee/kit/middleware"
"sort"
"strings"
)
type Matcher interface {
Use(ms ...middleware.Middleware)
Add(selector string, ms ...middleware.Middleware)
Match(operation string) []middleware.Middleware
}
func New() Matcher {
return &matcher{
matchs: make(map[string][]middleware.Middleware),
}
}
type matcher struct {
prefix []string
defaults []middleware.Middleware
matchs map[string][]middleware.Middleware
}
func (m *matcher) Use(ms ...middleware.Middleware) {
m.defaults = ms
}
func (m *matcher) Add(selector string, ms ...middleware.Middleware) {
if strings.HasSuffix(selector, "*") {
selector = strings.TrimSuffix(selector, "*")
m.prefix = append(m.prefix, selector)
sort.Slice(m.prefix, func(i, j int) bool {
return m.prefix[i] > m.prefix[j]
})
}
m.matchs[selector] = ms
}
func (m *matcher) Match(operation string) []middleware.Middleware {
ms := make([]middleware.Middleware, 0, len(m.defaults))
if len(m.defaults) > 0 {
ms = append(ms, m.defaults...)
}
if next, ok := m.matchs[operation]; ok {
ms = append(ms, next...)
}
for _, prefix := range m.prefix {
if strings.HasPrefix(operation, prefix) {
return append(ms, m.matchs[prefix]...)
}
}
return ms
}

@ -2,7 +2,9 @@ package protoc_gen_kit
import ( import (
"context" "context"
"git.diulo.com/mogfee/protoc-gen-kit/log" "git.diulo.com/mogfee/kit/log"
"git.diulo.com/mogfee/kit/registry"
"git.diulo.com/mogfee/kit/transport"
"net/url" "net/url"
"os" "os"
"time" "time"

@ -0,0 +1,65 @@
package registry
import (
"context"
"fmt"
"sort"
)
type Registrar interface {
Register(ctx context.Context, service *ServiceInstance) error
Deregister(ctx context.Context, service *ServiceInstance) error
}
type Discovery interface {
GetService(ctx context.Context, serverName string) ([]ServiceInstance, error)
Watch(ctx context.Context, serviceName string) (Watch, error)
}
type Watch interface {
Next() ([]*ServiceInstance, error)
Stop() error
}
type ServiceInstance struct {
ID string `json:"id"`
Name string `json:"name"`
Version string `json:"version"`
Metadata map[string]string `json:"metadata"`
Endpoints []string `json:"endpoints"`
}
func (i *ServiceInstance) String() string {
return fmt.Sprintf("%s-%s", i.Name, i.ID)
}
func (i *ServiceInstance) Equal(o any) bool {
if i == nil && o == nil {
return true
}
if i == nil || o == nil {
return false
}
t, ok := o.(*ServiceInstance)
if !ok {
return false
}
if len(i.Endpoints) != len(t.Endpoints) {
return false
}
sort.Strings(i.Endpoints)
sort.Strings(t.Endpoints)
for j := 0; j < len(i.Endpoints); j++ {
if i.Endpoints[j] != t.Endpoints[j] {
return false
}
}
if len(i.Metadata) != len(t.Metadata) {
return false
}
for k, v := range i.Metadata {
if v != t.Metadata[k] {
return false
}
}
return i.ID == t.ID && i.Name != t.Name && i.Version != t.Version
}

@ -0,0 +1,16 @@
package main
import (
"fmt"
protoc_gen_kit "git.diulo.com/mogfee/protoc-gen-kit"
)
func main() {
app := protoc_gen_kit.New(
protoc_gen_kit.Name("user-server"),
protoc_gen_kit.Server())
fmt.Println("run start")
app.Run()
fmt.Println("run end")
app.Stop()
}

@ -0,0 +1,25 @@
package binding
import (
"git.diulo.com/mogfee/kit/encoding"
"git.diulo.com/mogfee/kit/encoding/form"
"git.diulo.com/mogfee/kit/errors"
"net/http"
"net/url"
)
func BindQuery(vars url.Values, target any) error {
if err := encoding.GetCodec(form.Name).Unmarshal([]byte(vars.Encode()), target); err != nil {
return errors.BadRequest("CODEC", err.Error())
}
return nil
}
func BindForm(req *http.Request, target any) error {
if err := req.ParseForm(); err != nil {
return err
}
if err := encoding.GetCodec(form.Name).Unmarshal([]byte(req.Form.Encode()), target); err != nil {
return errors.BadRequest("CODEC", err.Error())
}
return nil
}

@ -0,0 +1,95 @@
package http
import (
"bytes"
"fmt"
"git.diulo.com/mogfee/kit/encoding"
"git.diulo.com/mogfee/kit/errors"
"git.diulo.com/mogfee/kit/internal/httputil"
"git.diulo.com/mogfee/kit/transport/http/binding"
"github.com/gorilla/mux"
"io"
"net/http"
"net/url"
)
const SupportPackageIsVersion1 = true
type Redirector interface {
Redirect() (string, int)
}
type Request = http.Request
type ResponseWriter = http.ResponseWriter
type Flusher = http.Flusher
type DecodeRequestFunc func(*http.Request, any) error
type EncodeResponseFunc func(http.ResponseWriter, *http.Request, any) error
type EncodeErrorFunc func(w http.ResponseWriter, r *http.Request, err error)
func DefaultRequestVars(r *http.Request, v any) error {
raws := mux.Vars(r)
vars := make(url.Values, len(raws))
for k, v := range raws {
vars[k] = []string{v}
}
return binding.BindQuery(vars, v)
}
func DefaultRequestQuery(r *http.Request, v any) error {
return binding.BindQuery(r.URL.Query(), v)
}
func DefaultRequestDecoder(r *http.Request, v any) error {
codec, ok := CodeForRequest(r, "Content-type")
if !ok {
return errors.BadRequest("CODEC", fmt.Sprintf("unregister Content-Type: %s", codec))
}
data, err := io.ReadAll(r.Body)
if err != nil {
return errors.BadRequest("CODEC", err.Error())
}
r.Body = io.NopCloser(bytes.NewBuffer(data))
if len(data) == 0 {
return nil
}
if err = codec.Unmarshal(data, v); err != nil {
return errors.BadRequest("CODEC", err.Error())
}
return nil
}
func DefaultResponseEncoder(w http.ResponseWriter, r *http.Request, v any) error {
if v == nil {
return nil
}
if rd, ok := v.(Redirector); ok {
url, code := rd.Redirect()
http.Redirect(w, r, url, code)
return nil
}
codec, _ := CodeForRequest(r, "Accept")
data, err := codec.Marshal(v)
if err != nil {
return err
}
w.Header().Set("Content-type", httputil.ContentType(codec.Name()))
_, err = w.Write(data)
return err
}
func DefaultErrorEncoder(w http.ResponseWriter, r *http.Request, err error) {
se := errors.FromError(err)
codec, _ := CodeForRequest(r, "Accept")
body, err := codec.Marshal(se)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", httputil.ContentType(codec.Name()))
w.WriteHeader(int(se.Status))
_, _ = w.Write(body)
}
func CodeForRequest(r *http.Request, name string) (encoding.Codec, bool) {
for _, accept := range r.Header[name] {
codec := encoding.GetCodec(httputil.ContentSubtype(accept))
if codec != nil {
return codec, true
}
}
return encoding.GetCodec("json"), false
}

@ -0,0 +1,201 @@
package http
import (
"context"
"encoding/json"
"encoding/xml"
"git.diulo.com/mogfee/kit/middleware"
"git.diulo.com/mogfee/kit/transport"
"git.diulo.com/mogfee/kit/transport/http/binding"
"github.com/gorilla/mux"
"io"
"net/http"
"net/url"
"time"
)
type Context interface {
context.Context
Vars() url.Values
Query() url.Values
Form() url.Values
Header() http.Header
Request() *http.Request
Response() http.ResponseWriter
Middleware(middleware.Handler) middleware.Handler
Bind(interface{}) error
BindVars(interface{}) error
BindQuery(interface{}) error
BindForm(interface{}) error
Returns(interface{}, error) error
Result(int, interface{}) error
JSON(int, interface{}) error
XML(int, interface{}) error
String(int, string) error
Blob(int, string, []byte) error
Stream(int, string, io.Reader) error
Reset(http.ResponseWriter, *http.Request)
}
type responseWriter struct {
code int
w http.ResponseWriter
}
func (r *responseWriter) Header() http.Header {
return r.w.Header()
}
func (r *responseWriter) Write(bytes []byte) (int, error) {
r.w.WriteHeader(r.code)
return r.w.Write(bytes)
}
func (r *responseWriter) WriteHeader(statusCode int) {
r.code = statusCode
}
func (r *responseWriter) reset(res http.ResponseWriter) {
r.w = res
r.code = http.StatusOK
}
type wrapper struct {
router *Router
req *http.Request
res http.ResponseWriter
w responseWriter
}
func (w *wrapper) Deadline() (deadline time.Time, ok bool) {
if w.req == nil {
return time.Time{}, false
}
return w.req.Context().Deadline()
}
func (w *wrapper) Done() <-chan struct{} {
if w.req == nil {
return nil
}
return w.req.Context().Done()
}
func (w *wrapper) Err() error {
if w.req == nil {
return context.Canceled
}
return w.req.Context().Err()
}
func (w *wrapper) Value(key any) any {
if w.req == nil {
return nil
}
return w.req.Context().Value(key)
}
func (w *wrapper) Vars() url.Values {
raws := mux.Vars(w.req)
vars := make(url.Values, len(raws))
for k, v := range raws {
vars[k] = []string{v}
}
return vars
}
func (w *wrapper) Query() url.Values {
return w.req.URL.Query()
}
func (w *wrapper) Form() url.Values {
if err := w.req.ParseForm(); err != nil {
return url.Values{}
}
return w.req.Form
}
func (w *wrapper) Header() http.Header {
return w.req.Header
}
func (w *wrapper) Request() *http.Request {
return w.req
}
func (w *wrapper) Response() http.ResponseWriter {
return w.res
}
func (w *wrapper) Middleware(handler middleware.Handler) middleware.Handler {
if tr, ok := transport.FromServerContext(w.req.Context()); ok {
return middleware.Chain(w.router.srv.middleware.Match(tr.Operation())...)(handler)
}
return middleware.Chain(w.router.srv.middleware.Match(w.req.URL.Path)...)(handler)
}
func (w *wrapper) Bind(i any) error {
return w.router.srv.decBody(w.req, i)
}
func (w *wrapper) BindVars(i interface{}) error {
return w.router.srv.decVars(w.req, i)
}
func (w *wrapper) BindQuery(i interface{}) error {
return w.router.srv.decQuery(w.req, i)
}
func (w *wrapper) BindForm(i interface{}) error {
return binding.BindForm(w.req, i)
}
func (w *wrapper) Returns(i interface{}, err error) error {
if err != nil {
return err
}
return w.router.srv.enc(&w.w, w.req, i)
}
func (w *wrapper) Result(i int, i2 interface{}) error {
w.w.WriteHeader(i)
return w.router.srv.enc(&w.w, w.req, i2)
}
func (w *wrapper) JSON(i int, i2 interface{}) error {
w.res.Header().Set("Content-Type", "application/json")
w.res.WriteHeader(i)
return json.NewEncoder(w.res).Encode(i2)
}
func (w *wrapper) XML(i int, i2 interface{}) error {
w.res.Header().Set("Content-Type", "application/xml")
w.res.WriteHeader(i)
return xml.NewEncoder(w.res).Encode(i2)
}
func (w *wrapper) String(i int, s string) error {
w.res.Header().Set("Content-Type", "text/plain")
w.res.WriteHeader(i)
_, err := w.res.Write([]byte(s))
return err
}
func (w *wrapper) Blob(i int, s string, bytes []byte) error {
w.res.Header().Set("Content-Type", s)
w.res.WriteHeader(i)
_, err := w.res.Write(bytes)
return err
}
func (w *wrapper) Stream(i int, s string, reader io.Reader) error {
w.res.Header().Set("Content-Type", s)
w.res.WriteHeader(i)
_, err := io.Copy(w.res, reader)
return err
}
func (w *wrapper) Reset(writer http.ResponseWriter, request *http.Request) {
w.w.reset(writer)
w.res = writer
w.req = request
}

@ -0,0 +1,14 @@
package http
import "net/http"
type FilterFunc func(http.Handler) http.Handler
func FilterChain(filters ...FilterFunc) FilterFunc {
return func(next http.Handler) http.Handler {
for i := len(filters) - 1; i >= 0; i-- {
next = filters[i](next)
}
return next
}
}

@ -0,0 +1,59 @@
package http
import (
"net/http"
"path"
"sync"
)
type WalkRoutFunc func(RouteInfo) error
type RouteInfo struct {
Path string
Method string
}
type HandlerFunc func(Context) error
type Router struct {
prefix string
pool sync.Pool
srv *Server
filters []FilterFunc
}
func newRouter(prefix string, srv *Server, filters ...FilterFunc) *Router {
r := &Router{
prefix: prefix,
srv: srv,
filters: filters,
}
r.pool.New = func() any {
return &wrapper{router: r}
}
return r
}
func (r *Router) Group(prefix string, filters ...FilterFunc) *Router {
var newFilters []FilterFunc
newFilters = append(newFilters, r.filters...)
newFilters = append(newFilters, filters...)
return newRouter(path.Join(r.prefix, prefix), r.srv, newFilters...)
}
func (r *Router) Handle(method string, relativePath string, h HandlerFunc, filters ...FilterFunc) {
next := http.Handler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
ctx := r.pool.Get().(Context)
ctx.Reset(res, req)
if err := h(ctx); err != nil {
r.srv.ene(res, req, err)
}
ctx.Reset(nil, nil)
r.pool.Put(ctx)
}))
next = FilterChain(filters...)(next)
next = FilterChain(r.filters...)(next)
r.srv.router.Handle(path.Join(r.prefix, relativePath), next).Methods(method)
}
func (r *Router) GET(path string, h HandlerFunc, m ...FilterFunc) {
r.Handle(http.MethodGet, path, h, m...)
}
func (r *Router) POST(path string, h HandlerFunc, m ...FilterFunc) {
r.Handle(http.MethodPost, path, h, m...)
}

@ -0,0 +1,263 @@
package http
import (
"context"
"crypto/tls"
"git.diulo.com/mogfee/kit/internal/matcher"
"git.diulo.com/mogfee/kit/middleware"
"git.diulo.com/mogfee/kit/transport"
"github.com/gorilla/mux"
"net"
"net/http"
"net/url"
"time"
)
var (
_ transport.Server = (*Server)(nil)
_ transport.Endpointer = (*Server)(nil)
_ http.Handler = (*Server)(nil)
)
type ServerOption func(server *Server)
// Network with server network.
func Network(network string) ServerOption {
return func(s *Server) {
s.network = network
}
}
// Address with server address.
func Address(addr string) ServerOption {
return func(s *Server) {
s.address = addr
}
}
// Timeout with server timeout.
func Timeout(timeout time.Duration) ServerOption {
return func(s *Server) {
s.timeout = timeout
}
}
// Logger with server logger.
// Deprecated: use global logger instead.
func Logger(logger log.Logger) ServerOption {
return func(s *Server) {}
}
// Middleware with service middleware option.
func Middleware(m ...middleware.Middleware) ServerOption {
return func(o *Server) {
o.middleware.Use(m...)
}
}
// Filter with HTTP middleware option.
func Filter(filters ...FilterFunc) ServerOption {
return func(o *Server) {
o.filters = filters
}
}
// RequestVarsDecoder with request decoder.
func RequestVarsDecoder(dec DecodeRequestFunc) ServerOption {
return func(o *Server) {
o.decVars = dec
}
}
// RequestQueryDecoder with request decoder.
func RequestQueryDecoder(dec DecodeRequestFunc) ServerOption {
return func(o *Server) {
o.decQuery = dec
}
}
// RequestDecoder with request decoder.
func RequestDecoder(dec DecodeRequestFunc) ServerOption {
return func(o *Server) {
o.decBody = dec
}
}
// ResponseEncoder with response encoder.
func ResponseEncoder(en EncodeResponseFunc) ServerOption {
return func(o *Server) {
o.enc = en
}
}
// ErrorEncoder with error encoder.
func ErrorEncoder(en EncodeErrorFunc) ServerOption {
return func(o *Server) {
o.ene = en
}
}
// TLSConfig with TLS config.
func TLSConfig(c *tls.Config) ServerOption {
return func(o *Server) {
o.tlsConf = c
}
}
// StrictSlash is with mux's StrictSlash
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
func StrictSlash(strictSlash bool) ServerOption {
return func(o *Server) {
o.strictSlash = strictSlash
}
}
// Listener with server lis
func Listener(lis net.Listener) ServerOption {
return func(s *Server) {
s.lis = lis
}
}
// PathPrefix with mux's PathPrefix, router will replaced by a subrouter that start with prefix.
func PathPrefix(prefix string) ServerOption {
return func(s *Server) {
s.router = s.router.PathPrefix(prefix).Subrouter()
}
}
type Server struct {
*http.Server
lis net.Listener
tlsConf *tls.Config
endpoint *url.URL
err error
network string
address string
timeout time.Duration
filters []FilterFunc
middleware matcher.Matcher
decVars DecodeRequestFunc
decQuery DecodeRequestFunc
decBody DecodeRequestFunc
enc EncodeResponseFunc
ene EncodeErrorFunc
strictSlash bool
router *mux.Router
}
func NewServer(opts ...ServerOption) *Server {
srv := &Server{
network: "tcp",
address: ":0",
timeout: 1 * time.Second,
middleware: matcher.New(),
decVars: DefaultRequestVars,
decQuery: DefaultRequestQuery,
decBody: DefaultRequestDecoder,
strictSlash: true,
router: mux.NewRouter(),
}
for _, o := range opts {
o(srv)
}
srv.router.StrictSlash(srv.strictSlash)
srv.router.NotFoundHandler = http.DefaultServeMux
srv.router.MethodNotAllowedHandler = http.DefaultServeMux
srv.router.Use(srv.filter())
srv.Server = &http.Server{
Handler: FilterChain(srv.filters...)(srv.router),
TLSConfig: srv.tlsConf,
}
return srv
}
func (s *Server) Use(selector string, m ...middleware.Middleware) {
s.middleware.Add(selector, m...)
}
func (s *Server) WalkRoute(fn WalkRoutFunc) error {
return s.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
}
path, err := route.GetPathTemplate()
if err != nil {
return err
}
for _, method := range methods {
if err = fn(RouteInfo{
Method: method,
Path: path,
}); err != nil {
return err
}
}
return nil
})
}
func (s *Server) Kind() transport.Kind {
return transport.htt
}
func (s *Server) Route(prefix string, filters ...FilterFunc) *Router {
return newRouter(prefix, s, filters...)
}
func (s *Server) Handle(path string, h http.Handler) {
s.router.Handle(path, h)
}
func (s *Server) HandlePrefix(prefix string, h http.Handler) {
s.router.PathPrefix(prefix).Handler(h)
}
func (s *Server) HandleFunc(path string, h http.HandlerFunc) {
s.router.HandleFunc(path, h)
}
func (s *Server) HandleHeader(key, val string, h http.Handler) {
s.router.Headers(key, val).Handler(h)
}
func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
s.Handler.ServeHTTP(res, req)
}
func (s *Server) filter() mux.MiddlewareFunc {
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
var (
ctx context.Context
cancel context.CancelFunc
)
if s.timeout > 0 {
ctx, cancel = context.WithTimeout(request.Context(), s.timeout)
} else {
ctx, cancel = context.WithCancel(request.Context())
}
defer cancel()
pathTemplate := request.URL.Path
if route := mux.CurrentRoute(request); route != nil {
pathTemplate, _ = route.GetPathTemplate()
}
tr := &Transport{}
})
}
}
func (s *Server) Endpoint() string {
//TODO implement me
panic("implement me")
}
func (s *Server) Operation() string {
//TODO implement me
panic("implement me")
}
func (s *Server) RequestHeader() transport.Header {
//TODO implement me
panic("implement me")
}
func (s *Server) ReplyHeader() transport.Header {
//TODO implement me
panic("implement me")
}

@ -0,0 +1,65 @@
package http
import (
"context"
"git.diulo.com/mogfee/kit/transport"
"net/http"
)
type Transporter interface {
transport.Transporter
Request() *http.Request
PathTemplate() string
}
type Transport struct {
endpoint string
operation string
reqHeader headerCarrier
replyHeader headerCarrier
request *http.Request
pathTemplate string
}
func (t *Transport) Kind() transport.Kind {
return transport.KindHTTP
}
func (t *Transport) Endpoint() string {
return t.endpoint
}
func (t *Transport) Operation() string {
return t.operation
}
func (t *Transport) RequestHeader() transport.Header {
return t.reqHeader
}
func (t *Transport) ReplyHeader() transport.Header {
return t.replyHeader
}
func (t *Transport) Request() *http.Request {
return t.request
}
func (t *Transport) PathTemplate() string {
return t.pathTemplate
}
func SetOperation(ctx context.Context, op string) {
if tr, ok := transport.FromServerContext(ctx); ok {
if tr, ok := tr.(*Transport); ok {
tr.operation = op
}
}
}
func RequestFromServerContext(ctx context.Context) (*http.Request, bool) {
if tr, ok := transport.FromServerContext(ctx); ok {
if tr, ok := tr.(*Transport); ok {
return tr.request, true
}
}
return nil, false
}

@ -0,0 +1,58 @@
package transport
import (
"context"
"net/url"
)
type Server interface {
Start(context.Context) error
Stop(context.Context) error
}
type Endpointer interface {
Endpoint() (*url.URL, error)
}
type Header interface {
Get(key string) string
Set(key string, value string)
Keys() []string
}
type Transporter interface {
Kind() Kind
Endpoint() string
Operation() string
RequestHeader() Header
ReplyHeader() Header
}
type Kind string
func (k Kind) String() string {
return string(k)
}
const (
KindGRPC Kind = "grpc"
KindHTTP Kind = "http"
)
type (
serverTransportKey struct{}
clientTransportKey struct{}
)
func NewServerContext(ctx context.Context, tr Transporter) context.Context {
return context.WithValue(ctx, serverTransportKey{}, tr)
}
func FromServerContext(ctx context.Context) (tr Transporter, ok bool) {
tr, ok = ctx.Value(serverTransportKey{}).(Transporter)
return
}
func NewClientContext(ctx context.Context, tr Transporter) context.Context {
return context.WithValue(ctx, clientTransportKey{}, tr)
}
func FromClientContext(ctx context.Context) (tr Transporter, ok bool) {
tr, ok = ctx.Value(clientTransportKey{}).(Transporter)
return
}
Loading…
Cancel
Save