package ddl import ( "encoding/json" "fmt" "git.diulo.com/mogfee/kit/mysql/parser" "github.com/antlr4-go/antlr/v4" "os" "strings" ) func Parser(fileName string, fun func(table *Table) error) error { body, err := os.ReadFile(fileName) if err != nil { return err } inputStream := antlr.NewInputStream(string(body)) caseStream := newCaseChangingStream(inputStream, true) lexer := parser.NewMySqlLexer(caseStream) lexer.RemoveErrorListeners() tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel) mysqlParser := parser.NewMySqlParser(tokens) mysqlParser.RemoveErrorListeners() vis := visitor{} data := mysqlParser.Root().Accept(&vis) if createTables, ok := data.([]*Table); ok { for _, table := range createTables { if err = fun(table); err != nil { return err } } } return nil } type visitor struct { parser.BaseMySqlParserVisitor } func (v *visitor) panicWithExpr(expr any, msg string) { panic(msg) } func (v *visitor) trace(msg ...interface{}) { fmt.Println("Visit Trace: " + fmt.Sprint(msg...)) } func (v *visitor) VisitRoot(ctx *parser.RootContext) interface{} { //v.trace("VisitRoot") if ctx.SqlStatements() != nil { return ctx.SqlStatements().Accept(v) } return nil } func (v *visitor) VisitSqlStatements(ctx *parser.SqlStatementsContext) interface{} { //v.trace("VisitSqlStatements") createTables := make([]*Table, 0) for _, e := range ctx.AllSqlStatement() { ret := e.Accept(v) if ret == nil { continue } if table, ok := ret.(*Table); ok { table.Imports = table.GetImports() for _, v := range table.Columns { v.GoType = v.GetTypeStr() v.Default = v.GetDefault() } createTables = append(createTables, table) } } return createTables } func (v *visitor) VisitSqlStatement(ctx *parser.SqlStatementContext) interface{} { //v.trace("VisitSqlStatement") if ctx.DdlStatement() != nil { return ctx.DdlStatement().Accept(v) } return nil } func (v *visitor) VisitDdlStatement(ctx *parser.DdlStatementContext) interface{} { //v.trace("VisitDdlStatement") if ctx.CreateTable() != nil { return v.visitCreateTable(ctx.CreateTable()) } return nil } func PrintJson(dat any) { b, _ := json.MarshalIndent(dat, " ", " ") fmt.Println(string(b)) } func GoName(str string) string { //数字开头 if str[0] >= 48 && str[0] <= 57 { str = "a" + str } arr := strings.Split(str, "_") newArr := []string{} for _, v := range arr { newArr = append(newArr, strings.Title(v)) } return strings.Join(newArr, "") }