Răsfoiți Sursa

fix:之前的流式hash实现会导致idf编译报错

doudou 2 ani în urmă
părinte
comite
0a80fb833c

+ 12 - 4
components/crypto/luat_crypto_mbedtls.c

@@ -274,17 +274,25 @@ int luat_crypto_md_stream_init(const char* md, const char* key, luat_crypt_strea
     return 0;
 }
 
-int luat_crypto_md_stream_update(const char* str, size_t str_size, luat_crypt_stream_t *stream) {
+int luat_crypto_md_stream_update(const char * md, const char* str, size_t str_size, luat_crypt_stream_t *stream) {
+    const mbedtls_md_info_t * info = mbedtls_md_info_from_string(md);
+    if (info == NULL) {
+        return -1;
+    }
     if (stream->key_len > 0){
         mbedtls_md_hmac_update(stream->ctx, (const unsigned char*)str, str_size);
     }
     else {
-        mbedtls_md_update(stream->ctx, str, str_size);
+        mbedtls_md_update(stream->ctx, (const unsigned char*)str, str_size);
     }
     return 0;
 }
 
-int luat_crypto_md_stream_finish(void* out_ptr, luat_crypt_stream_t *stream) {
+int luat_crypto_md_stream_finish(const char* md, void* out_ptr, luat_crypt_stream_t *stream) {
+    const mbedtls_md_info_t * info = mbedtls_md_info_from_string(md);
+    if (info == NULL) {
+        return -1;
+    }
     int ret = 0;
 
     if (stream->key_len > 0) {
@@ -294,7 +302,7 @@ int luat_crypto_md_stream_finish(void* out_ptr, luat_crypt_stream_t *stream) {
         ret = mbedtls_md_finish(stream->ctx, out_ptr);
     }
     if (ret == 0) {
-        unsigned char size = mbedtls_md_get_size(stream->ctx->md_info);
+        unsigned char size = mbedtls_md_get_size(info);
         mbedtls_md_free(stream->ctx);
         luat_heap_free(stream->ctx);
         stream->ctx = NULL;

+ 2 - 2
luat/include/luat_crypto.h

@@ -40,6 +40,6 @@ int luat_crypto_md(const char* md, const char* str, size_t str_size, void* out_p
 int luat_crypto_md_file(const char* md, void* out_ptr, const char* key, size_t key_len, const char* path);
 
 int luat_crypto_md_stream_init(const char* md, const char* key, luat_crypt_stream_t *stream);
-int luat_crypto_md_stream_update(const char* str, size_t str_size, luat_crypt_stream_t *stream);
-int luat_crypto_md_stream_finish(void* out_ptr, luat_crypt_stream_t *stream);
+int luat_crypto_md_stream_update(const char* md, const char* str, size_t str_size, luat_crypt_stream_t *stream);
+int luat_crypto_md_stream_finish(const char* md, void* out_ptr, luat_crypt_stream_t *stream);
 #endif

+ 15 - 12
luat/modules/luat_lib_crypto.c

@@ -666,7 +666,7 @@ static int l_crypto_md(lua_State *L) {
 
 /*
 创建流式hash用的stream
-@api crypto.hash_stream_init()
+@api crypto.hash_stream_init(tp)
 @string hash类型, 大写字母, 例如 "MD5" "SHA1" "SHA256"
 @string hmac值,可选
 @return userdata 成功返回一个数据结构,否则返回nil
@@ -691,7 +691,7 @@ static int l_crypt_hash_stream_init(lua_State *L) {
         const char* key = NULL;
         const char* md = luaL_checkstring(L, 1);
         if(lua_type(L, 2) == LUA_TSTRING) {
-            key = luaL_checklstring(L, 2, &(stream->key_len));
+            key = luaL_checklstring(L, 3, &(stream->key_len));
         }
         int ret = luat_crypto_md_stream_init(md, key, stream);
         if (ret < 0) {
@@ -705,35 +705,38 @@ static int l_crypt_hash_stream_init(lua_State *L) {
 
 /*
 流式hash更新数据
-@api crypto.hash_stream_update(stream, data)
+@api crypto.hash_stream_update(tp, stream, data)
+@string hash类型, 大写字母, 例如 "MD5" "SHA1" "SHA256"
 @userdata crypto.hash_stream_init()创建的stream, 必选
 @string 待计算的数据,必选
 @return 无
 @usage
-crypto.hash_stream_update(stream, "OK")
+crypto.hash_stream_update("MD5", stream, "OK")
 */
 static int l_crypt_hash_stream_update(lua_State *L) {
-    luat_crypt_stream_t *stream = (luat_crypt_stream_t *)luaL_checkudata(L, 1, LUAT_CRYPTO_TYPE);
-    const char *data = NULL;
+    const char* md = luaL_checkstring(L, 1);
+    luat_crypt_stream_t *stream = (luat_crypt_stream_t *)luaL_checkudata(L, 2, LUAT_CRYPTO_TYPE);
     size_t data_len = 0;
-    data = luaL_checklstring(L, 2, &data_len);
-    luat_crypto_md_stream_update(data, data_len ,stream);
+    const char *data = luaL_checklstring(L, 3, &data_len);
+    luat_crypto_md_stream_update(md, data, data_len ,stream);
     return 0;
 }
 
 /*
 获取流式hash校验值并释放创建的stream
-@api crypto.l_crypt_hash_stream_finish(stream)
+@api crypto.l_crypt_hash_stream_finish(tp, stream)
+@string hash类型, 大写字母, 例如 "MD5" "SHA1" "SHA256"
 @userdata crypto.hash_stream_init()创建的stream,必选
 @return string 成功返回计算得出的流式hash值的hex字符串,失败无返回
 @usage
-local hashResult = crypto.hash_stream_finish(stream)
+local hashResult = crypto.hash_stream_finish("MD5", stream)
 */
 static int l_crypt_hash_stream_finish(lua_State *L) {
-    luat_crypt_stream_t *stream = (luat_crypt_stream_t *)luaL_checkudata(L, 1, LUAT_CRYPTO_TYPE);
+    const char* md = luaL_checkstring(L, 1);
+    luat_crypt_stream_t *stream = (luat_crypt_stream_t *)luaL_checkudata(L, 2, LUAT_CRYPTO_TYPE);
     char buff[128] = {0};
     char output[64];
-    int ret = luat_crypto_md_stream_finish(output, stream);
+    int ret = luat_crypto_md_stream_finish(md, output, stream);
     LLOGD("finish result %d", ret);
     if (ret < 1) {
         return 0;