You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
107 lines
2.5 KiB
107 lines
2.5 KiB
package ddl |
|
|
|
import ( |
|
"encoding/json" |
|
"fmt" |
|
"git.diulo.com/mogfee/kit/mysql/parser" |
|
"github.com/antlr4-go/antlr/v4" |
|
"os" |
|
"strings" |
|
) |
|
|
|
func Parser(fileName string, fun func(table *Table) error) error { |
|
body, err := os.ReadFile(fileName) |
|
if err != nil { |
|
return err |
|
} |
|
inputStream := antlr.NewInputStream(string(body)) |
|
caseStream := newCaseChangingStream(inputStream, true) |
|
lexer := parser.NewMySqlLexer(caseStream) |
|
lexer.RemoveErrorListeners() |
|
tokens := antlr.NewCommonTokenStream(lexer, antlr.LexerDefaultTokenChannel) |
|
mysqlParser := parser.NewMySqlParser(tokens) |
|
mysqlParser.RemoveErrorListeners() |
|
vis := visitor{} |
|
data := mysqlParser.Root().Accept(&vis) |
|
if createTables, ok := data.([]*Table); ok { |
|
for _, table := range createTables { |
|
if err = fun(table); err != nil { |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
type visitor struct { |
|
parser.BaseMySqlParserVisitor |
|
} |
|
|
|
func (v *visitor) panicWithExpr(expr any, msg string) { |
|
panic(msg) |
|
} |
|
func (v *visitor) trace(msg ...interface{}) { |
|
fmt.Println("Visit Trace: " + fmt.Sprint(msg...)) |
|
} |
|
func (v *visitor) VisitRoot(ctx *parser.RootContext) interface{} { |
|
//v.trace("VisitRoot") |
|
if ctx.SqlStatements() != nil { |
|
return ctx.SqlStatements().Accept(v) |
|
} |
|
return nil |
|
} |
|
|
|
func (v *visitor) VisitSqlStatements(ctx *parser.SqlStatementsContext) interface{} { |
|
//v.trace("VisitSqlStatements") |
|
createTables := make([]*Table, 0) |
|
for _, e := range ctx.AllSqlStatement() { |
|
ret := e.Accept(v) |
|
if ret == nil { |
|
continue |
|
} |
|
if table, ok := ret.(*Table); ok { |
|
table.Imports = table.GetImports() |
|
for _, v := range table.Columns { |
|
v.GoType = v.GetTypeStr() |
|
v.Default = v.GetDefault() |
|
} |
|
|
|
createTables = append(createTables, table) |
|
} |
|
} |
|
return createTables |
|
} |
|
|
|
func (v *visitor) VisitSqlStatement(ctx *parser.SqlStatementContext) interface{} { |
|
//v.trace("VisitSqlStatement") |
|
if ctx.DdlStatement() != nil { |
|
return ctx.DdlStatement().Accept(v) |
|
} |
|
return nil |
|
} |
|
|
|
func (v *visitor) VisitDdlStatement(ctx *parser.DdlStatementContext) interface{} { |
|
//v.trace("VisitDdlStatement") |
|
if ctx.CreateTable() != nil { |
|
return v.visitCreateTable(ctx.CreateTable()) |
|
} |
|
return nil |
|
} |
|
|
|
// PrintJson writes dat to standard output as indented JSON followed by a
// newline. The JSON is produced by json.MarshalIndent with a one-space prefix
// and a one-space indent.
//
// If dat cannot be marshalled, the error is reported on standard error and
// nothing is written to standard output.
func PrintJson(dat any) {
	b, err := json.MarshalIndent(dat, " ", " ")
	if err != nil {
		fmt.Fprintf(os.Stderr, "PrintJson: %v\n", err)
		return
	}
	fmt.Println(string(b))
}
|
|
|
// GoName converts an underscore-separated name into a concatenated,
// title-cased identifier: "user_id" becomes "UserId".
//
// If str begins with an ASCII digit, "a" is prepended first, so "1abc"
// becomes "A1abc". An empty str yields an empty result.
func GoName(str string) string {
	// Guard the str[0] access below; indexing an empty string would panic.
	if str == "" {
		return ""
	}
	// Starts with a digit.
	if str[0] >= '0' && str[0] <= '9' {
		str = "a" + str
	}

	var b strings.Builder
	b.Grow(len(str))
	for _, part := range strings.Split(str, "_") {
		// strings.Title is deprecated but kept on purpose so the casing of
		// every segment stays exactly as before.
		b.WriteString(strings.Title(part))
	}
	return b.String()
}
|
|
|