Go學習【三】一個簡單的orm

碎語:(請自行跳過)git

距離上篇文章發佈也有半個月的時間了github

而後這半個月 也沒有用go寫項目或繼續學習 只能算簡單的入門了 之後若是有須要的話 或許會深刻的去了解一下這門語言 與各類經常使用的類庫 精力有限 把本身以前嘗試寫一個簡單orm的一些片斷與你們分享sql

也許在月底會嘗試用go去寫一個爬蟲 留待下篇文章分享數據庫

前言:app

關於go的orm框架有許多不錯的 爲何本身想寫一個緣由無非就是想經過寫orm的過程來對本身學習的知識作一個階段性的鍛鍊與檢驗 固然目前寫的這個只能算是一個玩具 若是你能在這個玩具裏有所收穫 那即是最好的了框架

正文:學習

技術需求:對反射有一些瞭解ui

反射能夠簡單的劃分爲如下幾步:
1獲取對象
t := reflect.TypeOf(arg) #獲取類型
v := reflect.ValueOf(arg) #獲取值
2獲取字段(值 或 名稱)
vf := v.Field(i)
fv := v.Field(i).Interface() #獲取值this

3設值
vf.CanSet() #判斷是否能夠設值
vf.setxxx(xx) code

而後插入 刪除 更新 能夠用相同的方法實現 只須要使用到 1 2 步

查詢會用到第 3 步

有了上面的這些知識咱們就能夠嘗試寫出一個orm框架了 閒話很少說上代碼

插入 刪除 與 更新省略

func insert(arg interface{}) (sql []byte, params []interface{}, kIdstr string, err error) {
    if arg == nil {
        err = errors.New("expected a pointer to a struct")
        return
    }
    var values []byte
    //獲取字段
    paramsMap, tableName, kIdcolumn, kIdstr := elem(arg)
    //拼裝sql語句
    sql = append(sql, []byte("INSERT "+tableName+" ( ")...)
    values = append(values, []byte(" VALUES (")...)
    for colum, v := range paramsMap {
        if colum != kIdcolumn {
            sql = append(sql, []byte(" "+colum+" ,")...)
            values = append(values, []byte(" ? ,")...)
            //獲取對應參數
            params = append(params, v)
        }

    }
    //拼裝成功
    sql = append(sql[:len(sql)-1], ')')
    values = append(values[:len(values)-1], ')')
    sql = append(sql, values...)
    log.Println("===>", string(sql), params)
    return
}
func elem(arg interface{}) (paramsMap map[string]interface{}, tableName, kIdcolumn, kIdFiled string) {
    t := reflect.TypeOf(arg)
    v := reflect.ValueOf(arg).Elem()
    //獲取表名
    if t.Kind() == reflect.Ptr {
        t = t.Elem()
        tableName = t.Name()
        log.Println("===> tableName:", tableName)
    }
    //獲取字段
    num := v.NumField()
    paramsMap = make(map[string]interface{}, num)
    for i := 0; i < num; i++ {
        //inteface 方法 非導出字段沒法使用
        if v.Field(i).CanInterface() {
            var tn string
            //獲取字段的值
            fv := v.Field(i).Interface()
            // 之後能夠改成tag 進行更好的擴展
            tf := t.Field(i)
            dC := tf.Tag.Get(dbColumn)
            if dC == "" {
                dC = tf.Tag.Get(dbID)
                if dC == "" {
                    tn = tf.Name
                } else {
                    kIdFiled = tf.Name
                    tn = dC
                    kIdcolumn = tn
                    dT := tf.Tag.Get(dbTableName)
                    if dT != "" {
                        tableName = dT
                    }
                }
            } else {
                tn = dC
            }
            paramsMap[tn] = fv
        } else {
            //此處省略判斷類型進行匹配
            //....
        }
    }
    return
}
func (this *Mysql) Insert(obj interface{}) error {
    query, param, kIdstr, err := insert(obj)
    if err != nil {
        return err
    }
    result, err := this.Exec(string(query), param...)
    if err != nil {
        return err
    }
    num, err := result.LastInsertId()
    if err != nil {
        return err
    }
    v := reflect.ValueOf(obj).Elem()
    vv := v.FieldByName(kIdstr)
    if vv.CanSet() {
        vv.SetInt(num)
    }
    return nil

}

查詢(目前只支持查詢單條數據 下一版會支持多條)

func selectOne(arg interface{}) (sql []byte, params []interface{}, err error) {
    if arg == nil {
        err = errors.New("expected a pointer to a struct")
        return
    }
    //獲取字段
    paramsMap, tableName, kIdcolumn, _ := elem(arg)
    var sqlWhere string
    //拼裝sql語句
    sql = append(sql, []byte("SELECT ")...)
    for colum, v := range paramsMap {
        sql = append(sql, []byte(" "+colum+" ,")...)
        if colum == kIdcolumn {
            sqlWhere = " WHERE " + colum + " = ? "
            params = append(params, v)
        }

    }
    sql = sql[:len(sql)-1]
    sql = append(sql, []byte("FROM "+tableName)...)
    sql = append(sql, []byte(sqlWhere)...)
    //拼裝成功
    log.Println("===>", string(sql), params)
    return
}
/*2016/06/19/22:35*/
func (this *Mysql) selectOne(obj interface{}, query string, params ...interface{}) (*sql.Rows, error) {
    if len(params) == 0 {
        return nil, fmt.Errorf("params is nil")
    }
    tx, err := this.DB.Begin()
    if err != nil {
        return nil, err
    }
    rows, err := tx.Query(query, params...)
    if err != nil {
        return nil, err
    }
    //進行設值 字段與數據庫對應關係
    filedCMap := filedColumnMapper(obj)
    //設值 須要更多詳細操做
    setFiled(obj, rows, filedCMap)

    err = tx.Commit()
    if err != nil {
        return nil, err
    }
    return rows, nil
}
//設值字段與數據的映射關係
func filedColumnMapper(obj interface{}) map[string]string {
    t := reflect.TypeOf(obj).Elem()
    v := reflect.ValueOf(obj).Elem()
    num := t.NumField()
    //獲取 字段 對應關係 ----此處應拿到buil-sql中
    filedCMap := make(map[string]string, num)
    for i := 0; i < num; i++ {
        //inteface 方法 非導出字段沒法使用
        if v.Field(i).CanInterface() {
            var tn string
            // 之後能夠改成tag 進行更好的擴展
            tf := t.Field(i)
            kC := tf.Tag.Get(dbColumn)
            if kC == "" {
                kC = tf.Tag.Get(dbColumn)
                if kC == "" {
                    tn = tf.Name
                } else {
                    tn = kC
                }
            } else {
                tn = kC
            }
            filedCMap[tn] = tf.Name
        } else {
            //此處省略判斷類型進行匹配
            //....
        }
    }
    return filedCMap
}

//爲字段設值
func setFiled(obj interface{}, rows *sql.Rows, filedCMap map[string]string) {
    //獲取鍵值對
    cols, _ := rows.Columns()
    buff := make([]interface{}, len(cols)) // 臨時slice
    data := make([]string, len(cols))      // 存數據slice
    for i, _ := range buff {
        buff[i] = &data[i]
    }
    for rows.Next() {
        rows.Scan(buff...) // ...是必須的
    }
    t := reflect.TypeOf(obj).Elem()
    v := reflect.ValueOf(obj).Elem()
    for k, values := range data {
        //根據 colum獲取字段名稱
        filedName := filedCMap[cols[k]]
        //進行設值
        if _, ok := t.FieldByName(filedName); ok {
            vft := v.FieldByName(filedName)
            switch vft.Kind() {
            case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                val, err := strconv.ParseInt(values, 10, 64)
                if err == nil {
                    vft.SetInt(val)
                }
            case reflect.String:
                vft.SetString(values)
            case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
                val, err := strconv.ParseUint(values, 10, 64)
                if err == nil {
                    vft.SetUint(val)
                }
            case reflect.Float32, reflect.Float64:
                val, err := strconv.ParseFloat(values, 64)
                if err == nil {
                    vft.SetFloat(val)
                }
            case reflect.Bool:
                val, err := strconv.ParseBool(values)
                if err == nil {
                    vft.SetBool(val)
                }

            }
        }

    }
}

func (this *Mysql) SelectOne(obj interface{}) error {
    query, param, err := selectOne(obj)
    if err != nil {
        return err
    }
    _, err = this.selectOne(obj, string(query), param...)

    if err != nil {
        return err
    }
    return nil
}

晚些時間會把代碼上傳到github 但願你們指出不足之處 和你們共同進步

相關文章
相關標籤/搜索