碎語:(請自行跳過)git
距離上篇文章發佈也有半個月的時間了github
而後這半個月 也沒有用go寫項目或繼續學習 只能算簡單的入門了 之後若是有須要的話 或許會深刻的去了解一下這門語言 與各類經常使用的類庫 精力有限 把本身以前嘗試寫一個簡單orm的一些片斷與你們分享sql
也許在月底會嘗試用go去寫一個爬蟲 留待下篇文章分享數據庫
前言:app
關於go的orm框架有許多不錯的 爲何本身想寫一個緣由無非就是想經過寫orm的過程來對本身學習的知識作一個階段性的鍛鍊與檢驗 固然目前寫的這個只能算是一個玩具 若是你能在這個玩具裏有所收穫 那即是最好的了框架
正文:學習
技術需求:對反射有一些瞭解ui
反射能夠簡單的劃分爲如下幾步:
1獲取對象
t := reflect.TypeOf(arg) #獲取類型
v := reflect.ValueOf(arg) #獲取值
2獲取字段(值 或 名稱)
vf := v.Field(i)
fv := v.Field(i).Interface() #獲取值this
3設值
vf.CanSet() #判斷是否能夠設值
vf.setxxx(xx) code
而後插入 刪除 更新 能夠用相同的方法實現 只須要使用到 1 2 步
查詢會用到第 3 步
有了上面的這些知識咱們就能夠嘗試寫出一個orm框架了 閒話很少說上代碼
插入 刪除 與 更新省略
func insert(arg interface{}) (sql []byte, params []interface{}, kIdstr string, err error) { if arg == nil { err = errors.New("expected a pointer to a struct") return } var values []byte //獲取字段 paramsMap, tableName, kIdcolumn, kIdstr := elem(arg) //拼裝sql語句 sql = append(sql, []byte("INSERT "+tableName+" ( ")...) values = append(values, []byte(" VALUES (")...) for colum, v := range paramsMap { if colum != kIdcolumn { sql = append(sql, []byte(" "+colum+" ,")...) values = append(values, []byte(" ? ,")...) //獲取對應參數 params = append(params, v) } } //拼裝成功 sql = append(sql[:len(sql)-1], ')') values = append(values[:len(values)-1], ')') sql = append(sql, values...) log.Println("===>", string(sql), params) return }
func elem(arg interface{}) (paramsMap map[string]interface{}, tableName, kIdcolumn, kIdFiled string) { t := reflect.TypeOf(arg) v := reflect.ValueOf(arg).Elem() //獲取表名 if t.Kind() == reflect.Ptr { t = t.Elem() tableName = t.Name() log.Println("===> tableName:", tableName) } //獲取字段 num := v.NumField() paramsMap = make(map[string]interface{}, num) for i := 0; i < num; i++ { //inteface 方法 非導出字段沒法使用 if v.Field(i).CanInterface() { var tn string //獲取字段的值 fv := v.Field(i).Interface() // 之後能夠改成tag 進行更好的擴展 tf := t.Field(i) dC := tf.Tag.Get(dbColumn) if dC == "" { dC = tf.Tag.Get(dbID) if dC == "" { tn = tf.Name } else { kIdFiled = tf.Name tn = dC kIdcolumn = tn dT := tf.Tag.Get(dbTableName) if dT != "" { tableName = dT } } } else { tn = dC } paramsMap[tn] = fv } else { //此處省略判斷類型進行匹配 //.... } } return }
func (this *Mysql) Insert(obj interface{}) error { query, param, kIdstr, err := insert(obj) if err != nil { return err } result, err := this.Exec(string(query), param...) if err != nil { return err } num, err := result.LastInsertId() if err != nil { return err } v := reflect.ValueOf(obj).Elem() vv := v.FieldByName(kIdstr) if vv.CanSet() { vv.SetInt(num) } return nil }
查詢(目前只支持查詢單條數據 下一版會支持多條)
func selectOne(arg interface{}) (sql []byte, params []interface{}, err error) { if arg == nil { err = errors.New("expected a pointer to a struct") return } //獲取字段 paramsMap, tableName, kIdcolumn, _ := elem(arg) var sqlWhere string //拼裝sql語句 sql = append(sql, []byte("SELECT ")...) for colum, v := range paramsMap { sql = append(sql, []byte(" "+colum+" ,")...) if colum == kIdcolumn { sqlWhere = " WHERE " + colum + " = ? " params = append(params, v) } } sql = sql[:len(sql)-1] sql = append(sql, []byte("FROM "+tableName)...) sql = append(sql, []byte(sqlWhere)...) //拼裝成功 log.Println("===>", string(sql), params) return }
/*2016/06/19/22:35*/ func (this *Mysql) selectOne(obj interface{}, query string, params ...interface{}) (*sql.Rows, error) { if len(params) == 0 { return nil, fmt.Errorf("params is nil") } tx, err := this.DB.Begin() if err != nil { return nil, err } rows, err := tx.Query(query, params...) if err != nil { return nil, err } //進行設值 字段與數據庫對應關係 filedCMap := filedColumnMapper(obj) //設值 須要更多詳細操做 setFiled(obj, rows, filedCMap) err = tx.Commit() if err != nil { return nil, err } return rows, nil }
//設值字段與數據的映射關係 func filedColumnMapper(obj interface{}) map[string]string { t := reflect.TypeOf(obj).Elem() v := reflect.ValueOf(obj).Elem() num := t.NumField() //獲取 字段 對應關係 ----此處應拿到buil-sql中 filedCMap := make(map[string]string, num) for i := 0; i < num; i++ { //inteface 方法 非導出字段沒法使用 if v.Field(i).CanInterface() { var tn string // 之後能夠改成tag 進行更好的擴展 tf := t.Field(i) kC := tf.Tag.Get(dbColumn) if kC == "" { kC = tf.Tag.Get(dbColumn) if kC == "" { tn = tf.Name } else { tn = kC } } else { tn = kC } filedCMap[tn] = tf.Name } else { //此處省略判斷類型進行匹配 //.... } } return filedCMap } //爲字段設值 func setFiled(obj interface{}, rows *sql.Rows, filedCMap map[string]string) { //獲取鍵值對 cols, _ := rows.Columns() buff := make([]interface{}, len(cols)) // 臨時slice data := make([]string, len(cols)) // 存數據slice for i, _ := range buff { buff[i] = &data[i] } for rows.Next() { rows.Scan(buff...) // ...是必須的 } t := reflect.TypeOf(obj).Elem() v := reflect.ValueOf(obj).Elem() for k, values := range data { //根據 colum獲取字段名稱 filedName := filedCMap[cols[k]] //進行設值 if _, ok := t.FieldByName(filedName); ok { vft := v.FieldByName(filedName) switch vft.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: val, err := strconv.ParseInt(values, 10, 64) if err == nil { vft.SetInt(val) } case reflect.String: vft.SetString(values) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: val, err := strconv.ParseUint(values, 10, 64) if err == nil { vft.SetUint(val) } case reflect.Float32, reflect.Float64: val, err := strconv.ParseFloat(values, 64) if err == nil { vft.SetFloat(val) } case reflect.Bool: val, err := strconv.ParseBool(values) if err == nil { vft.SetBool(val) } } } } } func (this *Mysql) SelectOne(obj interface{}) error { query, param, err := selectOne(obj) if err != nil { return err } _, err = this.selectOne(obj, string(query), param...) if err != nil { return err } return nil }
晚些時間會把代碼上傳到github 但願你們指出不足之處 和你們共同進步