You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

192 lines
3.7 KiB

package search
import (
"errors"
"fmt"
)
// colon is the leading character that marks a route segment as a named
// parameter (see match and getChildren).
const colon = ':'

// slash is the character that separates route segments and that every
// route must start with.
const slash = '/'
// Sentinel errors used inside the package.
var (
	// errDupItem is reported by add when the target node already holds an item.
	errDupItem = errors.New("duplicated item")
	// errDupSlash is reported by add when the remaining route starts with '/'.
	errDupSlash = errors.New("duplicated slash")
	// errEmptyItem is returned by Add when the item is nil.
	errEmptyItem = errors.New("empty item")
	// errInvalidState is reported by add when a child entry exists but is nil.
	errInvalidState = errors.New("search tree is in an invalid state")
	// errNotFromRoot is returned by Add when the route is empty or does not
	// start with '/'.
	errNotFromRoot = errors.New("path should start with /")
)

// NotFound is the zero Result. Search returns it when the route is empty or
// does not start with '/'.
var NotFound Result
// innerResult is the outcome of comparing one stored pattern with one path
// segment (see match).
type innerResult struct {
	// key is the parameter name (the pattern without its leading ':').
	key string
	// value is the path segment bound to key.
	value string
	// named reports whether the pattern was a ':'-prefixed parameter.
	named bool
	// found reports whether the pattern accepted the segment.
	found bool
}

// node is a single level of the search tree.
type node struct {
	// item is the value registered for the route ending at this node, or nil.
	item any
	// children[0] holds literal segments and children[1] holds segments that
	// start with ':' (see getChildren); both are keyed by the segment text.
	children [2]map[string]*node
}

// TreeNode is a route search tree; create one with NewTree.
type TreeNode struct {
	root *node
}

// Result is what Search produces for a matched route.
type Result struct {
	// Item is the value that was registered with Add.
	Item any
	// Params maps parameter names to the path segments they matched; it is
	// nil when no named segment took part in the match.
	Params map[string]string
}
// NewTree returns an empty search tree whose root node holds no item.
func NewTree() *TreeNode {
	tree := new(TreeNode)
	tree.root = newNode(nil)
	return tree
}
// Add registers item under route.
//
// route must start with '/'; the rest is split on '/' into segments, and a
// segment beginning with ':' is stored as a named parameter. item must not be
// nil.
//
// It returns errNotFromRoot for an empty route or one that does not start
// with '/', errEmptyItem for a nil item, and a descriptive error naming the
// full route when the route is already registered or contains an empty
// segment. Any other error from the insertion is returned unchanged.
func (t *TreeNode) Add(route string, item any) error {
	if len(route) == 0 || route[0] != slash {
		return errNotFromRoot
	}
	if item == nil {
		return errEmptyItem
	}

	err := add(t.root, route[1:], item)
	// Match sentinels with errors.Is so the mapping keeps working even if add
	// starts wrapping its errors with extra context.
	switch {
	case errors.Is(err, errDupItem):
		return duplicatedItem(route)
	case errors.Is(err, errDupSlash):
		return duplicatedSlash(route)
	default:
		return err
	}
}
// Search looks up route in the tree.
//
// It returns the matched Result and true on success. A route that is empty or
// does not start with '/' yields NotFound and false; otherwise the boolean is
// whatever the recursive lookup from the root reports.
func (t *TreeNode) Search(route string) (Result, bool) {
	if route == "" || route[0] != slash {
		return NotFound, false
	}

	found := Result{}
	matched := t.next(t.root, route[1:], &found)
	return found, matched
}
// next matches route (already stripped of its leading '/') against the
// subtree rooted at n and records the outcome in result.
//
// It returns true when an item was found; result.Item is then set and every
// named segment along the matched path has been added to result.Params.
// Children are visited through foreach, so when more than one child pattern
// accepts a segment the first one visited that leads to an item wins.
func (t *TreeNode) next(n *node, route string, result *Result) bool {
	// Nothing left to consume and this node carries an item: done.
	if route == "" && n.item != nil {
		result.Item = n.item
		return true
	}

	// Find where the first segment ends.
	sep := -1
	for i := 0; i < len(route); i++ {
		if route[i] == slash {
			sep = i
			break
		}
	}

	if sep < 0 {
		// route is the final segment: a child matches only if it holds an item.
		return n.foreach(func(pattern string, child *node) bool {
			m := match(pattern, route)
			if !m.found || child.item == nil {
				return false
			}
			result.Item = child.item
			if m.named {
				addParam(result, m.key, m.value)
			}
			return true
		})
	}

	segment, rest := route[:sep], route[sep+1:]
	return n.foreach(func(pattern string, child *node) bool {
		m := match(pattern, segment)
		if !m.found {
			return false
		}
		// Only record the parameter once the remainder has matched, so a
		// failed branch leaves result untouched.
		if !t.next(child, rest, result) {
			return false
		}
		if m.named {
			addParam(result, m.key, m.value)
		}
		return true
	})
}
// foreach calls fn for each child of nd, walking children[0] before
// children[1], and stops at the first call that returns true. It reports
// whether any call returned true.
func (nd *node) foreach(fn func(string, *node) bool) bool {
	for idx := range nd.children {
		for key, child := range nd.children[idx] {
			if fn(key, child) {
				return true
			}
		}
	}

	return false
}
// add inserts item below nd following route, which must not have a leading
// '/'. Intermediate nodes are created as needed.
//
// It returns errDupItem when the target node already holds an item,
// errDupSlash when route starts with '/' (an empty segment), and
// errInvalidState when an existing child entry is nil.
func add(nd *node, route string, item any) error {
	// The whole route has been consumed: nd itself is the target.
	if len(route) == 0 {
		if nd.item != nil {
			return errDupItem
		}
		nd.item = item
		return nil
	}

	if route[0] == slash {
		return errDupSlash
	}

	for i := range route {
		if route[i] != slash {
			continue
		}

		// route[:i] is an intermediate segment; descend into (or create) it.
		token := route[:i]
		children := nd.getChildren(token)
		child, ok := children[token]
		switch {
		case !ok:
			child = newNode(nil)
			children[token] = child
		case child == nil:
			return errInvalidState
		}
		return add(child, route[i+1:], item)
	}

	// No separator left: route is the final segment.
	children := nd.getChildren(route)
	child, ok := children[route]
	if !ok {
		children[route] = newNode(item)
		return nil
	}
	// Report a nil entry the same way the intermediate-segment branch does,
	// instead of dereferencing it.
	if child == nil {
		return errInvalidState
	}
	if child.item != nil {
		return errDupItem
	}
	child.item = item
	return nil
}
// getChildren returns the child map that route's first segment belongs to:
// children[1] when route starts with ':', children[0] otherwise (including
// when route is empty).
func (nd *node) getChildren(route string) map[string]*node {
	bucket := 0
	if route != "" && route[0] == colon {
		bucket = 1
	}
	return nd.children[bucket]
}
// duplicatedItem builds the error reported when item (a route) is already
// registered.
func duplicatedItem(item string) error {
	const format = "duplicated item for %s"
	return fmt.Errorf(format, item)
}
// duplicatedSlash builds the error reported when item (a route) contains an
// empty segment, i.e. two consecutive slashes.
func duplicatedSlash(item string) error {
	const format = "duplicated slash for %s"
	return fmt.Errorf(format, item)
}
// addParam stores k=v in result.Params, creating the map on first use.
func addParam(result *Result, k, v string) {
	if result.Params != nil {
		result.Params[k] = v
		return
	}

	result.Params = map[string]string{k: v}
}
// match compares a stored pattern with a path segment.
//
// A pattern starting with ':' accepts any token and yields a named result
// whose key is the pattern without the ':' and whose value is the token.
// Any other pattern is found only when it equals the token exactly.
// pat must be non-empty: its first byte is read unconditionally.
func match(pat, token string) innerResult {
	if pat[0] != colon {
		return innerResult{found: pat == token}
	}

	var res innerResult
	res.key = pat[1:]
	res.value = token
	res.named = true
	res.found = true
	return res
}
// newNode returns a node holding item with both child maps allocated and
// empty.
func newNode(item any) *node {
	nd := &node{item: item}
	for i := range nd.children {
		nd.children[i] = make(map[string]*node)
	}
	return nd
}