You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
2.5 KiB
108 lines
2.5 KiB
1 year ago
|
package ddl
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"git.diulo.com/mogfee/kit/mysql/parser"
|
||
|
"github.com/antlr4-go/antlr/v4"
|
||
|
"os"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
func Parser(fileName string, fun func(table *Table) error) error {
|
||
|
body, err := os.ReadFile(fileName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
inputStream := antlr.NewInputStream(string(body))
|
||
|
caseStream := newCaseChangingStream(inputStream, true)
|
||
|
lexer := parser.NewMySqlLexer(caseStream)
|
||
|
lexer.RemoveErrorListeners()
|
||
|
tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel)
|
||
|
mysqlParser := parser.NewMySqlParser(tokens)
|
||
|
mysqlParser.RemoveErrorListeners()
|
||
|
vis := visitor{}
|
||
|
data := mysqlParser.Root().Accept(&vis)
|
||
|
if createTables, ok := data.([]*Table); ok {
|
||
|
for _, table := range createTables {
|
||
|
if err = fun(table); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
type visitor struct {
|
||
|
parser.BaseMySqlParserVisitor
|
||
|
}
|
||
|
|
||
|
func (v *visitor) panicWithExpr(expr any, msg string) {
|
||
|
panic(msg)
|
||
|
}
|
||
|
func (v *visitor) trace(msg ...interface{}) {
|
||
|
fmt.Println("Visit Trace: " + fmt.Sprint(msg...))
|
||
|
}
|
||
|
func (v *visitor) VisitRoot(ctx *parser.RootContext) interface{} {
|
||
|
//v.trace("VisitRoot")
|
||
|
if ctx.SqlStatements() != nil {
|
||
|
return ctx.SqlStatements().Accept(v)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (v *visitor) VisitSqlStatements(ctx *parser.SqlStatementsContext) interface{} {
|
||
|
//v.trace("VisitSqlStatements")
|
||
|
createTables := make([]*Table, 0)
|
||
|
for _, e := range ctx.AllSqlStatement() {
|
||
|
ret := e.Accept(v)
|
||
|
if ret == nil {
|
||
|
continue
|
||
|
}
|
||
|
if table, ok := ret.(*Table); ok {
|
||
|
table.Imports = table.GetImports()
|
||
|
for _, v := range table.Columns {
|
||
|
v.GoType = v.GetTypeStr()
|
||
|
v.Default = v.GetDefault()
|
||
|
}
|
||
|
|
||
|
createTables = append(createTables, table)
|
||
|
}
|
||
|
}
|
||
|
return createTables
|
||
|
}
|
||
|
|
||
|
func (v *visitor) VisitSqlStatement(ctx *parser.SqlStatementContext) interface{} {
|
||
|
//v.trace("VisitSqlStatement")
|
||
|
if ctx.DdlStatement() != nil {
|
||
|
return ctx.DdlStatement().Accept(v)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (v *visitor) VisitDdlStatement(ctx *parser.DdlStatementContext) interface{} {
|
||
|
//v.trace("VisitDdlStatement")
|
||
|
if ctx.CreateTable() != nil {
|
||
|
return v.visitCreateTable(ctx.CreateTable())
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// PrintJson prints dat to stdout as JSON indented by one space per level,
// intended for debugging. If dat cannot be encoded, a line describing the
// encoding error is printed instead of an empty line.
func PrintJson(dat any) {
	fmt.Println(formatJSON(dat))
}

// formatJSON renders dat as indented JSON, or as "PrintJson: <error>" when
// encoding fails. An empty prefix is used so that the closing bracket lines
// up with the opening one (a non-empty prefix is applied to every line but
// the first).
func formatJSON(dat any) string {
	b, err := json.MarshalIndent(dat, "", " ")
	if err != nil {
		return "PrintJson: " + err.Error()
	}
	return string(b)
}
|
||
|
|
||
|
// GoName converts a snake_case SQL identifier into an exported-style Go
// identifier. Each "_"-separated part is passed through strings.Title and
// the parts are concatenated, e.g. "user_id" -> "UserId". A name starting
// with an ASCII digit is first prefixed with "a" ("1st_col" -> "A1stCol"),
// because Go identifiers cannot begin with a digit. An empty name yields "".
func GoName(str string) string {
	if str == "" {
		return ""
	}
	// Starts with a digit: not a valid start for a Go identifier.
	if str[0] >= '0' && str[0] <= '9' {
		str = "a" + str
	}

	var b strings.Builder
	b.Grow(len(str))
	for _, part := range strings.Split(str, "_") {
		// strings.Title is deprecated, but it is kept so that output stays the
		// same for parts containing other word separators (e.g. '-' or spaces).
		b.WriteString(strings.Title(part))
	}
	return b.String()
}
|