You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

108 lines
2.5 KiB

1 year ago
package ddl
import (
"encoding/json"
"fmt"
"git.diulo.com/mogfee/kit/mysql/parser"
"github.com/antlr4-go/antlr/v4"
"os"
"strings"
)
// Parser reads the MySQL DDL script stored at fileName, parses it, and calls
// fun once for every *Table the tree visitor produces, in source order.
//
// A read failure is returned unchanged. Iteration stops at the first non-nil
// error returned by fun, and that error is returned. Both the lexer's and the
// parser's error listeners are removed, so syntax problems are not reported
// by this function; if the walk does not yield a []*Table, Parser simply
// returns nil without calling fun.
func Parser(fileName string, fun func(table *Table) error) error {
	content, readErr := os.ReadFile(fileName)
	if readErr != nil {
		return readErr
	}

	// NOTE(review): newCaseChangingStream is defined elsewhere; the `true`
	// flag presumably normalizes input case for the grammar — confirm.
	source := newCaseChangingStream(antlr.NewInputStream(string(content)), true)

	lex := parser.NewMySqlLexer(source)
	lex.RemoveErrorListeners()

	p := parser.NewMySqlParser(antlr.NewCommonTokenStream(lex, antlr.LexerDefaultTokenChannel))
	p.RemoveErrorListeners()

	tables, isTables := p.Root().Accept(&visitor{}).([]*Table)
	if !isTables {
		return nil
	}
	for _, tbl := range tables {
		if cbErr := fun(tbl); cbErr != nil {
			return cbErr
		}
	}
	return nil
}
// visitor walks the MySQL parse tree and turns CREATE TABLE statements into
// *Table values (see VisitDdlStatement / VisitSqlStatements).
// Embedding BaseMySqlParserVisitor promotes its methods, so every visit
// method this type does not override falls back to the base implementation.
type visitor struct {
	parser.BaseMySqlParserVisitor
}
// panicWithExpr aborts the tree walk by panicking with msg (a string) as the
// panic value.
//
// NOTE(review): expr is accepted but not used; presumably it was meant to add
// context about the offending expression to the message — confirm before
// relying on it.
func (v *visitor) panicWithExpr(expr any, msg string) {
	_ = expr // deliberately ignored; see the note above
	panic(msg)
}
// trace prints one line to standard output: the prefix "Visit Trace: "
// followed by msg formatted with fmt.Sprint. The call sites visible in this
// file are commented out, so it is effectively a manual debugging switch.
func (v *visitor) trace(msg ...interface{}) {
	text := fmt.Sprint(msg...)
	fmt.Println("Visit Trace: " + text)
}
// VisitRoot is the entry point of the walk. It forwards to the script's
// statement list via Accept (presumably dispatching to VisitSqlStatements)
// and returns that result unchanged; a root with no statement list yields nil.
func (v *visitor) VisitRoot(ctx *parser.RootContext) interface{} {
	stmts := ctx.SqlStatements()
	if stmts == nil {
		return nil
	}
	return stmts.Accept(v)
}
// VisitSqlStatements visits every statement of the script and collects the
// *Table values the statements produce, in source order.
//
// Before a table is collected its derived fields are filled in: Imports from
// GetImports, and each column's GoType and Default from GetTypeStr and
// GetDefault. Statements whose visit yields nil or anything other than a
// *Table are skipped. The result is always a non-nil []*Table (possibly empty).
func (v *visitor) VisitSqlStatements(ctx *parser.SqlStatementsContext) interface{} {
	statements := ctx.AllSqlStatement()
	createTables := make([]*Table, 0, len(statements))
	for _, stmt := range statements {
		// A nil result fails the type assertion, so no separate nil check is needed.
		table, ok := stmt.Accept(v).(*Table)
		if !ok {
			continue
		}
		table.Imports = table.GetImports()
		// The loop variable is named col (not v) so it does not shadow the receiver.
		// NOTE(review): these writes only persist if Columns holds pointers — confirm
		// against the Table definition.
		for _, col := range table.Columns {
			col.GoType = col.GetTypeStr()
			col.Default = col.GetDefault()
		}
		createTables = append(createTables, table)
	}
	return createTables
}
// VisitSqlStatement handles a single statement. Only DDL statements are
// visited (their result is returned as-is); every other kind of statement
// yields nil and is therefore ignored by the caller.
func (v *visitor) VisitSqlStatement(ctx *parser.SqlStatementContext) interface{} {
	ddl := ctx.DdlStatement()
	if ddl == nil {
		return nil
	}
	return ddl.Accept(v)
}
// VisitDdlStatement handles a DDL statement. CREATE TABLE statements are
// converted by visitCreateTable (defined elsewhere in the package); all other
// DDL statements yield nil.
func (v *visitor) VisitDdlStatement(ctx *parser.DdlStatementContext) interface{} {
	if createTable := ctx.CreateTable(); createTable != nil {
		return v.visitCreateTable(createTable)
	}
	return nil
}
// PrintJson writes dat to standard output as indented JSON followed by a
// newline. Every line after the first is prefixed with one space, and each
// nesting level adds one more space of indentation.
//
// If dat cannot be marshaled (for example it contains a channel or a func),
// the marshaling error is printed instead. Without this check the discarded
// error would leave an empty, misleading line.
func PrintJson(dat any) {
	b, err := json.MarshalIndent(dat, " ", " ")
	if err != nil {
		fmt.Println("PrintJson:", err)
		return
	}
	fmt.Println(string(b))
}
// GoName converts a snake_case name (such as a table or column name) into an
// exported Go-style identifier, e.g. "user_id" becomes "UserId".
//
// The name is split on underscores and each part is passed through
// strings.Title. If the name starts with an ASCII digit it is first prefixed
// with "a", because a Go identifier cannot begin with a digit ("1st" becomes
// "A1st"). An empty name yields "" instead of panicking.
func GoName(str string) string {
	if str == "" {
		return ""
	}
	// Starts with a digit: Go identifiers may not, so prepend a letter.
	if str[0] >= '0' && str[0] <= '9' {
		str = "a" + str
	}
	var b strings.Builder
	b.Grow(len(str))
	for _, part := range strings.Split(str, "_") {
		// strings.Title is deprecated (SA1019) but kept on purpose so names that
		// contain other separators (spaces, '-') convert exactly as before.
		b.WriteString(strings.Title(part))
	}
	return b.String()
}