LOADING

加载过慢请开启缓存 浏览器默认开启

hi3861与智谱ai的realtime模型对话(成功返回文本)

2026/2/24

做了一点不想做了,看了一些毕设的评论突然很着急,为什么他们进度那么快啊,我们学校都没什么通知,只是在1月30号让我们申请选题。为什么有的三月已经准备完成了,为什么有的上学期完成了,有点急啊。

唉,昨天发现一个喜欢的up主注销账号了,明明14号的时候才关注我,那时候真的很开心。新年也结束了,悲伤的氛围。又要找工作。又要准备毕设。真的烦死了。

static_library("i2s_demo") {
      sources = [
          # "i2s_demo.c",
          "realtime_streaming_demo.c",
          "i2s_codec_common.c",
          "i2s_microphone.c",
          "ws_client_demo.c",
          "//third_party/musl/src/math/sinf.c",
          "//third_party/musl/src/math/cosf.c",
          "//third_party/musl/src/math/__sindf.c",
          "//third_party/musl/src/math/__cosdf.c",
          "//third_party/musl/src/math/__rem_pio2f.c",
          "//third_party/musl/src/math/__rem_pio2_large.c",
          "//third_party/musl/src/math/__sin.c",
          "//third_party/musl/src/math/__cos.c",
          "//third_party/musl/src/math/__rem_pio2.c",
      ]

     include_dirs = [
         ".",
         "//device/soc/hisilicon/hi3861v100/sdk_liteos/include",
         "//base/iothardware/peripheral/interfaces/inner_api",
         "//commonlibrary/utils_lite/include",
         "//kernel/liteos_m/kal/include",
         "//foundation/communication/interfaces/kits/wifi_lite/wifiservice",
         "//device/soc/hisilicon/hi3861v100/sdk_liteos/third_party/lwip_sack/include/lwip",
         "//device/soc/hisilicon/hi3861v100/sdk_liteos/third_party/mbedtls/include",
         "//third_party/librws",
         "//third_party/musl/src/internal",
         "//third_party/musl/porting/liteos_m/kernel/src/internal/"
     ]

     defines = [ "WITH_LWIP" ]

     # 链接 librws 静态库和 mbedtls
     deps = [
         "//third_party/librws/librws:librws_static",
         "//device/soc/hisilicon/hi3861v100/sdk_liteos/third_party/mbedtls:mbedtls",
     ]

}

realtime_streaming_demo.c

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <ohos_init.h>
#include "hi_types_base.h"
#include "hi_config.h"
#include "hi_i2s.h"
#include "hi_i2c.h"
#include "hi_gpio.h"
#include "hi_io.h"
#include "hi_time.h"
#include "hi_mem.h"
#include "hi_dma.h"
#include "hi_stdlib.h"
#include "hi_wifi_api.h"
#include "lwip/netif.h"
#include "lwip/netifapi.h"
#include "lwip/ip_addr.h"
#include "librws.h"
#include "i2s_microphone.h"
#include "ws_client_demo.h"
#include "i2s_codec_common.h"
#include "mbedtls/base64.h"
#include "cmsis_os2.h"

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#define WIFI_SSID "ZTE-cCCKEK"
#define WIFI_PASSWORD "13471551659abc"

#define I2S_AUDIO_CHUNK_SIZE         320     

#define ES8311_RESET_REG00 0x00
#define ES8311_CLK_MANAGER_REG01 0x01
#define ES8311_CLK_MANAGER_REG02 0x02
#define ES8311_CLK_MANAGER_REG03 0x03
#define ES8311_CLK_MANAGER_REG04 0x04
#define ES8311_CLK_MANAGER_REG05 0x05
#define ES8311_CLK_MANAGER_REG06 0x06
#define ES8311_CLK_MANAGER_REG07 0x07
#define ES8311_CLK_MANAGER_REG08 0x08
#define ES8311_SDPIN_REG09 0x09
#define ES8311_SDPOUT_REG0A 0x0A
#define ES8311_SYSTEM_REG0B 0x0B
#define ES8311_SYSTEM_REG0C 0x0C
#define ES8311_SYSTEM_REG0D 0x0D
#define ES8311_SYSTEM_REG0E 0x0E
#define ES8311_SYSTEM_REG0F 0x0F
#define ES8311_SYSTEM_REG10 0x10
#define ES8311_SYSTEM_REG11 0x11
#define ES8311_SYSTEM_REG12 0x12
#define ES8311_SYSTEM_REG13 0x13
#define ES8311_SYSTEM_REG14 0x14
#define ES8311_ADC_REG15 0x15
#define ES8311_ADC_REG16 0x16
#define ES8311_ADC_REG17 0x17
#define ES8311_ADC_REG18 0x18
#define ES8311_ADC_REG19 0x19
#define ES8311_ADC_REG1A 0x1A
#define ES8311_ADC_REG1B 0x1B
#define ES8311_ADC_REG1C 0x1C
#define ES8311_DAC_REG31 0x31
#define ES8311_DAC_REG32 0x32
#define ES8311_DAC_REG33 0x33
#define ES8311_DAC_REG34 0x34
#define ES8311_DAC_REG35 0x35
#define ES8311_DAC_REG37 0x37
#define ES8311_GPIO_REG44 0x44
#define ES8311_GP_REG45 0x45

#define CODEC_DEVICE_ADDR 0x30

hi_u32 realtime_demo_init_i2s(hi_void)
{
    hi_u32 ret;
    hi_codec_attribute codec_attr = {0};
    
    printf("[I2S] I2S 初始化开始...\n");
    
    ret = hi_gpio_init();
    if (ret != HI_ERR_SUCCESS) {
        printf("[I2S] GPIO 初始化失败: 0x%x\n", ret);
        return ret;
    }
    
    hi_io_set_func(HI_IO_NAME_GPIO_0, HI_IO_FUNC_GPIO_0_I2C1_SDA);
    ret |= hi_io_set_func(HI_IO_NAME_GPIO_1, HI_IO_FUNC_GPIO_1_I2C1_SCL);
    ret |= hi_io_set_func(HI_IO_NAME_GPIO_5, HI_IO_FUNC_GPIO_5_I2S0_MCLK);
    ret |= hi_io_set_func(HI_IO_NAME_GPIO_6, HI_IO_FUNC_GPIO_6_I2S0_TX);
    ret |= hi_io_set_func(HI_IO_NAME_GPIO_7, HI_IO_FUNC_GPIO_7_I2S0_BCLK);
    ret |= hi_io_set_func(HI_IO_NAME_GPIO_8, HI_IO_FUNC_GPIO_8_I2S0_WS);
    ret |= hi_io_set_func(HI_IO_NAME_GPIO_11, HI_IO_FUNC_GPIO_11_I2S0_RX);
    
    if (ret != HI_ERR_SUCCESS) {
        printf("[I2S] GPIO 功能设置失败: 0x%x\n", ret);
        return ret;
    }
    
    ret = hi_i2c_init(HI_I2C_IDX_1, 100000);
    if (ret != HI_ERR_SUCCESS) {
        printf("[I2S] I2C 初始化失败: 0x%x\n", ret);
        return ret;
    }
    
    ret = hi_dma_init();
    if (ret != HI_ERR_SUCCESS) {
        printf("[I2S] DMA 初始化失败: 0x%x\n", ret);
        return ret;
    }
    
    codec_attr.sample_rate = HI_CODEC_SAMPLE_RATE_16K;
    codec_attr.resolution = HI_CODEC_RESOLUTION_16BIT;
    
    ret = codec_init(&codec_attr);
    if (ret != HI_ERR_SUCCESS) {
        printf("[I2S] Codec 初始化失败: 0x%x\n", ret);
        return ret;
    }
    
    hi_i2s_attribute i2s_cfg = {0};
    i2s_cfg.sample_rate = HI_I2S_SAMPLE_RATE_16K;
    i2s_cfg.resolution = HI_I2S_RESOLUTION_16BIT;
    
    ret = hi_i2s_init(&i2s_cfg);
    if (ret != HI_ERR_SUCCESS) {
        printf("[I2S] I2S 初始化失败: 0x%x\n", ret);
        return ret;
    }
    
    printf("[I2S] I2S 初始化成功\n");
    printf("[I2S] 采样率: 16kHz\n");
    printf("[I2S] 声道数: 1\n");
    printf("[I2S] 位深: 16 bits\n");
    
    return HI_ERR_SUCCESS;
}

static hi_void test_realtime_echo(hi_void)
{
    hi_u32 ret;

    printf("\n");
    printf("========================================\n");
    printf("  I2S 测试:实时回声(优化版)\n");
    printf("========================================\n\n");

    // 清空缓冲区
    ret = hi_i2s_write(HI_NULL, 0, 0);

    hi_u32 echo_duration_ms = 5000;
    hi_u32 chunk_size = I2S_AUDIO_CHUNK_SIZE; // 10ms块
    hi_u8 *chunk_buf = hi_malloc(HI_MOD_ID_DRV, chunk_size);

    if (chunk_buf == HI_NULL)
    {
        printf("[ECHO] 缓冲区分配失败\n");
        return;
    }

    hi_u32 start_time = hi_get_tick();
    hi_u32 end_time = start_time + echo_duration_ms;
    hi_u32 processed_chunks = 0;
    hi_u32 read_ret, write_ret;

    while (hi_get_tick() < end_time)
    {
        // 使用超时机制,避免无限阻塞
        read_ret = hi_i2s_read(chunk_buf, chunk_size, 50); // 50ms超时

        if (read_ret == HI_ERR_SUCCESS)
        {
            write_ret = hi_i2s_write(chunk_buf, chunk_size, 50); // 50ms超时

            if (write_ret == HI_ERR_SUCCESS)
            {
                processed_chunks++;

                if (processed_chunks % 50 == 0)
                {
                    hi_u32 elapsed = hi_get_tick() - start_time;
                    printf("[ECHO] 处理中: %u ms, %u 个数据块\n", elapsed, processed_chunks);
                }
            }
        }

    }

    hi_free(HI_MOD_ID_DRV, chunk_buf);

    printf("[ECHO] 实时回声测试完成,处理了 %u 个数据块\n", processed_chunks);
    printf("[ECHO] 每个数据块 %u 字节,总数据量 %u 字节\n",
           chunk_size, processed_chunks * chunk_size);
}

hi_void realtime_demo_cleanup_i2s(hi_void)
{
    hi_i2s_deinit();
    hi_i2c_deinit(HI_I2C_IDX_1);
}

void on_ws_audio_received(const hi_u8 *pcm_data, hi_u32 len)
{
    static hi_u32 total_audio_received = 0;
    static hi_u32 total_audio_played = 0;

    total_audio_received += len;
    printf("[AUDIO] 收到音频数据: %u bytes (总计: %u bytes)\n", len, total_audio_received);

    if (pcm_data != HI_NULL && len > 0) {
        hi_u32 offset = 0;
        hi_u32 chunk_size = I2S_AUDIO_CHUNK_SIZE;
        hi_u32 ret;
        hi_u32 retry_count = 0;

        while (offset < len) {
            hi_u32 write_size = (offset + chunk_size > len) ? (len - offset) : chunk_size;
            ret = hi_i2s_write((hi_u8 *)(pcm_data + offset), write_size, 100);

            if (ret == HI_ERR_SUCCESS) {
                offset += write_size;
                total_audio_played += write_size;
                retry_count = 0;
            } else {
                retry_count++;
                printf("[AUDIO] I2S 写入失败,重试 %u/10\n", retry_count);
                hi_sleep(5);

                if (retry_count >= 10) {
                    printf("[AUDIO] I2S 写入失败,跳过此音频块\n");
                    break;
                }
            }
        }

        printf("[AUDIO] 音频播放完成: %u / %u bytes\n", offset, len);
    }
}


static void Realtime_Streaming_Demo_Entry_Task(const char *arg)
{
    hi_u32 ret;

    (void)arg;

    printf("\n");
    printf("========================================\n");
    printf("  实时流传输演示程序\n");
    printf("========================================\n\n");

    printf("[MAIN] 步骤 1: I2S 录音和回放测试\n");
    ret = realtime_demo_init_i2s();
    if (ret != HI_ERR_SUCCESS) {
        printf("[MAIN] I2S 初始化失败,程序退出\n");
        return;
    }
    hi_sleep(1000);
    
    printf("[MAIN] 步骤 2: I2S 实时回声测试\n");
    //test_realtime_echo();

    printf("[MAIN] 步骤 3: 连接 WiFi\n");
    ret = realtime_demo_connect_wifi(WIFI_SSID, WIFI_PASSWORD);
    if (ret == HI_FALSE) {
        printf("[MAIN] WiFi 连接失败,程序退出\n");
        return;
    }

    hi_sleep(2000);

    printf("[MAIN] 步骤 4: 连接 WebSocket\n");
    ws_set_audio_callback(on_ws_audio_received);
    ws_client_init();

    // 等待 WebSocket 连接建立
    printf("[MAIN] 等待 WebSocket 连接建立...\n");
    hi_u32 wait_count = 0;
    while (!ws_is_connected() && wait_count < 100) {
        hi_sleep(100);
        wait_count++;
    }

    if (!ws_is_connected()) {
        printf("[MAIN] WebSocket 连接超时\n");
        return;
    }

    printf("[MAIN] WebSocket 已连接,等待会话配置完成...\n");
    hi_sleep(3000);  // 等待 session.update 和 session.updated 完成

    printf("[MAIN] 步骤 5: 开始发送测试音频\n");
    ws_start_audio_sending();

    // 保持任务运行,等待音频发送完成
    printf("[MAIN] 主任务继续运行...\n");
    while (1) {
        hi_sleep(1000);
        // 可以在这里添加其他逻辑,比如检查连接状态等
    }
}

void realtime_streaming_demo_init(hi_void)
{
    osThreadAttr_t attr;

    attr.name = "realtime_streaming";
    attr.attr_bits = 0U;
    attr.cb_mem = NULL;
    attr.cb_size = 0U;
    attr.stack_mem = NULL;
    attr.stack_size = 0x8000;
    attr.priority = osPriorityNormal;

    if (osThreadNew((osThreadFunc_t)Realtime_Streaming_Demo_Entry_Task, NULL, &attr) == NULL) {
        printf("[MAIN] 创建实时流传输任务失败\n");
    }
}

void Realtime_Streaming_Demo_Entry(void)
{
    osThreadAttr_t attr;

    attr.name = "realtime_streaming";
    attr.attr_bits = 0U;
    attr.cb_mem = NULL;
    attr.cb_size = 0U;
    attr.stack_mem = NULL;
    attr.stack_size = 0x8000;
    attr.priority = osPriorityNormal;

    if (osThreadNew((osThreadFunc_t)Realtime_Streaming_Demo_Entry_Task, NULL, &attr) == NULL) {
        printf("[MAIN] 创建实时流传输任务失败\n");
    }
}


SYS_RUN(Realtime_Streaming_Demo_Entry);

i2s_codec_common.c

/*
 * I2S 和 Codec 共享函数
 */

#include <stdio.h>
#include <string.h>
#include <ohos_init.h>
#include <hi_types_base.h>
#include <hi_config.h>
#include <hi_i2s.h>
#include <hi_i2c.h>
#include <hi_gpio.h>
#include <hi_io.h>
#include <hi_time.h>
#include <hi_mem.h>
#include <hi_dma.h>
#include <hi_stdlib.h>
#include "hi_wifi_api.h"
#include "lwip/netif.h"
#include "lwip/netifapi.h"
#include "lwip/ip_addr.h"

// Codec 寄存器定义
#define ES8311_RESET_REG00 0x00
#define ES8311_CLK_MANAGER_REG01 0x01
#define ES8311_CLK_MANAGER_REG02 0x02
#define CODEC_DEVICE_ADDR 0x30

typedef enum {
    HI_CODEC_SAMPLE_RATE_8K = 8,
    HI_CODEC_SAMPLE_RATE_16K = 16,
    HI_CODEC_SAMPLE_RATE_24K = 24,
    HI_CODEC_SAMPLE_RATE_32K = 32,
    HI_CODEC_SAMPLE_RATE_48K = 48,
} hi_codec_sample_rate;

typedef enum {
    HI_CODEC_RESOLUTION_16BIT = 16,
} hi_codec_resolution;

typedef struct {
    hi_codec_sample_rate sample_rate;
    hi_codec_resolution resolution;
} hi_codec_attribute;

// 全局网络接口
struct netif *g_lwip_netif = NULL;

/**
 * 写入 Codec 寄存器
 */
hi_u32 codec_write_reg(hi_u8 reg, hi_u8 val)
{
    hi_i2c_data i2c_data;
    hi_u8 send_data[2] = { reg, val };
    i2c_data.send_buf = send_data;
    i2c_data.send_len = 2;
    return hi_i2c_write(HI_I2C_IDX_1, CODEC_DEVICE_ADDR, &i2c_data);
}

/**
 * 设置 Codec 增益
 */
hi_u32 codec_set_gain(hi_void)
{
    hi_u32 ret = HI_ERR_SUCCESS;
    ret |= codec_write_reg(0x0E, 0x02);
    ret |= codec_write_reg(0x0F, 0x44);
    ret |= codec_write_reg(0x15, 0x40);
    ret |= codec_write_reg(0x1B, 0x0A);
    ret |= codec_write_reg(0x1C, 0x6A);
    ret |= codec_write_reg(0x17, 0xBF);
    ret |= codec_write_reg(0x37, 0x48);
    ret |= codec_write_reg(0x32, 0x84);
    ret |= codec_write_reg(0x16, 0x22);
    ret |= codec_write_reg(0x17, 0xDF);
    ret |= codec_write_reg(0x18, 0x87);
    ret |= codec_write_reg(0x19, 0xFB);
    ret |= codec_write_reg(0x1A, 0x03);
    ret |= codec_write_reg(0x1B, 0xEA);
    return ret;
}

/**
 * 初始化 Codec
 */
hi_u32 codec_init(const hi_codec_attribute *codec_attr)
{
    if (codec_attr == HI_NULL) {
        return HI_ERR_FAILURE;
    }

    hi_u32 ret;
    ret = codec_write_reg(0x44, 0x08);
    hi_udelay(5000);
    ret = codec_write_reg(0x31, 0x40);
    ret |= codec_write_reg(0x00, 0x1F);
    ret |= codec_write_reg(0x45, 0x00);
    ret |= codec_write_reg(0x01, 0x30);
    ret |= codec_write_reg(0x02, 0x10);

    // 设置采样率
    if (codec_attr->sample_rate == HI_CODEC_SAMPLE_RATE_8K) {
        ret |= codec_write_reg(0x02, 0xA0);
    } else if (codec_attr->sample_rate == HI_CODEC_SAMPLE_RATE_16K) {
        ret |= codec_write_reg(0x02, 0x40);
    } else if (codec_attr->sample_rate == HI_CODEC_SAMPLE_RATE_32K) {
        ret |= codec_write_reg(0x02, 0x48); 
    } else if (codec_attr->sample_rate == HI_CODEC_SAMPLE_RATE_48K) {
        ret |= codec_write_reg(0x02, 0x00);
    }

    ret |= codec_write_reg(0x03, 0x10);
    ret |= codec_write_reg(0x16, 0x24);
    ret |= codec_write_reg(0x04, 0x10);
    ret |= codec_write_reg(0x05, 0x00);
    ret |= codec_write_reg(0x0B, 0x00);
    ret |= codec_write_reg(0x0C, 0x00);
    ret |= codec_write_reg(0x10, 0x1F);
    ret |= codec_write_reg(0x11, 0x7F);
    ret |= codec_write_reg(0x00, 0x80);
    hi_udelay(50000);
    ret |= codec_write_reg(0x0D, 0x01);
    ret |= codec_write_reg(0x01, 0x3F);
    ret |= codec_write_reg(0x14, 0x18);
    ret |= codec_write_reg(0x12, 0x00);
    ret |= codec_write_reg(0x13, 0x10);
    if ((codec_attr->resolution == HI_CODEC_RESOLUTION_16BIT))
    {
        /* set adc/dac data format */
        ret |= codec_write_reg(0x09, 0x0C);  /* set dac format=16bit i2s */
        ret |= codec_write_reg(0x0A, 0x0C); /* set adc format=16bit i2s */
    }
    else
    {
        /* set adc/dac data format */
        ret |= codec_write_reg(0x09, 0x00);                /* set dac format=24bit i2s */
        ret |= codec_write_reg(0x0A, 0x00);                /* set adc format=24bit i2s */
    }
    ret |= codec_set_gain();
    ret |= codec_write_reg(0x31, 0x00);
    return ret;
}

/**
 * WiFi 事件处理
 */
void wifi_event_handler(const hi_wifi_event *event)
{
    if (event == NULL) {
        return;
    }

    switch (event->event) {
        case HI_WIFI_EVT_SCAN_DONE:
            printf("WiFi: Scan completed\n");
            break;
        case HI_WIFI_EVT_CONNECTED:
            printf("WiFi: Connected to AP\n");
            printf("WiFi: SSID: %s\n", event->info.wifi_connected.ssid);
            if (g_lwip_netif != NULL) {
                netifapi_dhcp_start(g_lwip_netif);
            }
            break;
        case HI_WIFI_EVT_DISCONNECTED:
            printf("WiFi: Disconnected from AP\n");
            if (g_lwip_netif != NULL) {
                netifapi_dhcp_stop(g_lwip_netif);
                ip4_addr_t st_gw, st_ipaddr, st_netmask;
                IP4_ADDR(&st_gw, 0, 0, 0, 0);
                IP4_ADDR(&st_ipaddr, 0, 0, 0, 0);
                IP4_ADDR(&st_netmask, 0, 0, 0, 0);
                netifapi_netif_set_addr(g_lwip_netif, &st_ipaddr, &st_netmask, &st_gw);
            }
            break;
        default:
            break;
    }
}

/**
 * 连接 WiFi
 */
hi_bool realtime_demo_connect_wifi(const char *ssid, const char *password)
{
    int ret;
    hi_wifi_assoc_request assoc_req = {0};
    char ifname[WIFI_IFNAME_MAX_SIZE + 1] = {0};
    int len = sizeof(ifname);
    const unsigned char wifi_vap_res_num = 2;
    const unsigned char wifi_user_res_num = 2;

    printf("========================================\n");
    printf("  步骤 1: 连接 WiFi\n");
    printf("========================================\n");
    printf("SSID: %s\n", ssid);

    if (hi_wifi_get_init_status() == 1) {
        printf("WiFi: Already initialized\n");
    } else {
        ret = hi_wifi_init(wifi_vap_res_num, wifi_user_res_num);
        if (ret != HISI_OK) {
            printf("WiFi: Init failed, ret = %d\n", ret);
            return HI_FALSE;
        }
        printf("WiFi: Init success\n");
    }

    ret = hi_wifi_sta_start(ifname, &len);
    if (ret != HISI_OK) {
        printf("WiFi: STA start failed, ret = %d\n", ret);
        hi_wifi_deinit();
        return HI_FALSE;
    }
    printf("WiFi: STA start success, ifname: %s\n", ifname);

    ret = hi_wifi_register_event_callback(wifi_event_handler);
    if (ret != HISI_OK) {
        printf("WiFi: Register event callback failed, ret = %d\n", ret);
        hi_wifi_sta_stop();
        hi_wifi_deinit();
        return HI_FALSE;
    }
    printf("WiFi: Event callback registered\n");

    g_lwip_netif = netifapi_netif_find(ifname);
    if (g_lwip_netif == NULL) {
        printf("WiFi: Get netif failed\n");
        hi_wifi_sta_stop();
        hi_wifi_deinit();
        return HI_FALSE;
    }

    strncpy(assoc_req.ssid, ssid, HI_WIFI_MAX_SSID_LEN);
    assoc_req.auth = HI_WIFI_SECURITY_WPA2PSK;
    strncpy(assoc_req.key, password, HI_WIFI_MAX_KEY_LEN);

    ret = hi_wifi_sta_connect(&assoc_req);
    if (ret != HISI_OK) {
        printf("WiFi: Connect failed, ret = %d\n", ret);
        hi_wifi_sta_stop();
        hi_wifi_deinit();
        return HI_FALSE;
    }

    printf("WiFi: Waiting for IP address...\n");
    sleep(10);

    if (g_lwip_netif == NULL) {
        printf("WiFi: Netif is NULL, cannot get IP address\n");
        return HI_FALSE;
    }

    const ip4_addr_t *ipaddr, *netmask, *gw;
    ipaddr = netif_ip4_addr(g_lwip_netif);
    netmask = netif_ip4_netmask(g_lwip_netif);
    gw = netif_ip4_gw(g_lwip_netif);

    if (ipaddr == NULL || netmask == NULL || gw == NULL) {
        printf("WiFi: Failed to get IP address info\n");
        return HI_FALSE;
    }

    printf("WiFi: IP address: %s\n", ip4addr_ntoa(ipaddr));
    printf("WiFi: Netmask: %s\n", ip4addr_ntoa(netmask));
    printf("WiFi: Gateway: %s\n", ip4addr_ntoa(gw));
    printf("\n");

    return HI_TRUE;
}

i2s_codec_common.h

/*
 * I2S 和 Codec 共享函数头文件
 */

#ifndef __I2S_CODEC_COMMON_H__
#define __I2S_CODEC_COMMON_H__

#include "hi_types_base.h"

// Codec 属性类型
typedef enum {
    HI_CODEC_SAMPLE_RATE_8K = 8,
    HI_CODEC_SAMPLE_RATE_16K = 16,
    HI_CODEC_SAMPLE_RATE_24K = 24,
    HI_CODEC_SAMPLE_RATE_32K = 32,
    HI_CODEC_SAMPLE_RATE_48K = 48,
} hi_codec_sample_rate;

typedef enum {
    HI_CODEC_RESOLUTION_16BIT = 16,
} hi_codec_resolution;

typedef struct {
    hi_codec_sample_rate sample_rate;
    hi_codec_resolution resolution;
} hi_codec_attribute;

// 全局网络接口声明
extern struct netif *g_lwip_netif;

/**
 * 写入 Codec 寄存器
 */
hi_u32 codec_write_reg(hi_u8 reg, hi_u8 val);

/**
 * 设置 Codec 增益
 */
hi_u32 codec_set_gain(hi_void);

/**
 * 初始化 Codec
 */
hi_u32 codec_init(const hi_codec_attribute *codec_attr);

/**
 * WiFi 事件处理
 */
void wifi_event_handler(const hi_wifi_event *event);

/**
 * 连接 WiFi
 */
hi_bool realtime_demo_connect_wifi(const char *ssid, const char *password);

#endif /* __I2S_CODEC_COMMON_H__ */

ws_client_demo.c

/*
 * WebSocket Client Demo for Hi3861
 * 连接到本地 PC 代理服务器,通过代理访问智谱AI GLM-Realtime API
 *
 * 功能:
 * 1. 连接到 PC 代理服务器 (ws://PC_IP:8080)
 * 2. 发送会话配置事件
 * 3. 接收服务器响应
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// 定义数学常量
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#include <ohos_init.h>
#include "hi_wifi_api.h"
#include "hi_time.h"
#include "lwip/netif.h"
#include "lwip/netifapi.h"
#include "librws.h"
#include "ws_client_demo.h"
#include "mbedtls/base64.h"
#include "cmsis_os2.h"
#include "i2s_microphone.h"

// ==================== 配置区域 (请修改此处) ====================

// TODO: 填写你的 PC IP 地址
// 确保 PC 和设备在同一个局域网,PC 已启动 proxy_server.py
// 获取 PC IP: Windows 上运行 ipconfig, Linux/Mac 上运行 ifconfig
#define PC_PROXY_IP "192.168.1.2"
#define PC_PROXY_PORT 8080

// ==================== 以下代码无需修改 ====================

#define SERVER_HOST PC_PROXY_IP
#define SERVER_PORT PC_PROXY_PORT
#define SERVER_PATH "/"

static rws_socket g_ws_socket = NULL;
static hi_bool g_is_connected = HI_FALSE;
static ws_on_audio_callback g_audio_callback = NULL;

// 音频发送任务控制
static osThreadId_t g_audio_send_task = NULL;
static hi_bool g_audio_send_running = HI_FALSE;

// 优化后的缓冲区,减少内存占用
static hi_u8 g_audio_temp_buf[8192];  // 8KB 足够处理单次 WebSocket 音频消息
static char g_response_text_buf[256];  // 响应文本缓冲区

// 音频发送任务使用的全局缓冲区,避免栈溢出
static hi_u8 g_audio_send_chunk[320];  // 音频数据块
static hi_u8 g_audio_send_base64_buf[512];  // base64编码缓冲区

// 函数前置声明
static hi_void start_audio_sending_task(hi_void);

/**
 * 解码 Base64 音频数据并调用回调 - 暂时禁用以避免内存耗尽
 */
static hi_void process_audio_delta(const char *base64_data)
{
    // 暂时禁用音频播放,避免内存耗尽导致内核崩溃
    // TODO: 实现流式音频播放,避免一次性解码大数据
    printf("[WS] 音频数据已跳过 (大小: %u 字符) - 音频播放功能暂时禁用\n", (hi_u32)strlen(base64_data));
    return;

    /* 原代码已禁用,待实现流式处理后再启用

    if (!g_audio_callback || !base64_data) {
        return;
    }

    static hi_u8 decode_buf[4096];  // 4KB 解码缓冲区,分多次解码
    size_t input_len = strlen(base64_data);
    const char *p = base64_data;
    const char *end = base64_data + input_len;

    // 安全地处理引号
    if (input_len > 0 && *p == '"') {
        p++;
        input_len--;
    }
    if (input_len > 0 && *(end - 1) == '"') {
        end--;
        input_len--;
    }

    // 再次检查长度
    if (input_len == 0) {
        printf("[WS] Empty base64 data\n");
        return;
    }

    // 如果数据太大,分块解码
    const size_t max_base64_chunk = 3000;  // 每次最多解码 3000 字符 Base64
    size_t remaining = input_len;
    const char *current_pos = p;

    while (remaining > 0)
    {
        size_t chunk_size = (remaining > max_base64_chunk) ? max_base64_chunk : remaining;

        size_t output_len = 0;
        int ret = mbedtls_base64_decode(decode_buf, sizeof(decode_buf), &output_len,
                                         (const unsigned char *)current_pos, chunk_size);

        if (ret != 0) {
            printf("[WS] Base64 decode failed, err = %d\n", ret);
            // 继续尝试下一块
        } else if (output_len > 0) {
            // 成功解码,播放音频
            printf("[WS] 播放音频块: %u bytes (remaining: %u)\n", output_len, (hi_u32)remaining);
            g_audio_callback(decode_buf, output_len);
        }

        current_pos += chunk_size;
        remaining -= chunk_size;
    }

    printf("[WS] 音频数据处理完成\n");
    */
}

/* 连接成功回调 */
/* 音频发送任务入口函数 - 使用 I2S 麦克风实时录音 */
static void audio_send_task_entry(void *arg)
{
    printf("[AUDIO] 实时录音任务开始\n");

    // 初始化 I2S 麦克风
    i2s_microphone_init();
    i2s_microphone_start();

    printf("[AUDIO] 等待连接稳定...\n");
    hi_sleep(2000);

    printf("[AUDIO] 开始实时录音并发送...\n");
    printf("[AUDIO] 请对着麦克风说话...\n");

    size_t base64_len;
    hi_u32 frame_count = 0;
    hi_u32 read_size;

    // 持续录音并发送,直到任务被停止
    while (g_audio_send_running && ws_is_connected())
    {
        // 从 I2S 麦克风读取一帧音频数据 (320 bytes = 20ms)
        read_size = i2s_microphone_read_frame(g_audio_send_chunk, sizeof(g_audio_send_chunk));

        if (read_size > 0)
        {
            // Base64 编码
            mbedtls_base64_encode(g_audio_send_base64_buf, sizeof(g_audio_send_base64_buf),
                                      &base64_len, g_audio_send_chunk, read_size);
            g_audio_send_base64_buf[base64_len] = '\0';

            // 发送到服务器
            ws_send_audio_base64((const char *)g_audio_send_base64_buf);

            frame_count++;

            // 每 50 帧打印一次(约 1 秒)
            if (frame_count % 50 == 0)
            {
                printf("[AUDIO] 已发送 %u 帧\n", frame_count);
            }

            // 控制发送速率:每帧 20ms,避免超过 50 QPS
            // 50 QPS = 每 100ms 最多 5 帧 = 每 20ms 1 帧
            hi_sleep(10);  // 10ms 延迟
        }
        else
        {
            // 如果读取失败,稍微等待后重试
            hi_sleep(1);
        }
    }

    printf("[AUDIO] 录音任务结束,共发送 %u 帧\n", frame_count);

    // 停止麦克风
    i2s_microphone_stop();

    g_audio_send_running = HI_FALSE;
    g_audio_send_task = NULL;
}


/* 启动音频发送任务 - 公共函数 */
hi_void ws_start_audio_sending(hi_void)
{
    if (g_audio_send_task != NULL)
    {
        printf("[AUDIO] 任务已在运行\n");
        return;
    }

    // 先设置运行标志,再创建任务(避免竞态条件)
    g_audio_send_running = HI_TRUE;

    // 使用 CMSIS-RTOS v2 API 创建任务
    // Hi3861内存有限,减少栈大小
    osThreadAttr_t attr = {
        .name = "audio_send",
        .stack_size = 0x2000,  // 8KB - 减少栈大小防止内存不足
        .priority = osPriorityNormal,
    };

    g_audio_send_task = osThreadNew((osThreadFunc_t)audio_send_task_entry, NULL, &attr);

    if (g_audio_send_task == NULL)
    {
        printf("[AUDIO] 创建任务失败\n");
        g_audio_send_running = HI_FALSE;
        return;
    }

    printf("[AUDIO] 音频发送任务已创建\n");
}

/* 连接成功回调 - 添加实际发送代码 */
static void on_connected(rws_socket socket)
{
    printf("[WS] connected\n");
    g_is_connected = HI_TRUE;

    // 生成时间戳
    hi_u32 timestamp = hi_get_tick();

    // 使用静态缓冲区避免栈溢出 - 根据文档示例构建完整的 session.update
    static char session_update[1024];
    static char event_id_str[64];

    // 生成 UUID 格式的 event_id(简化版)
    snprintf(event_id_str, sizeof(event_id_str), "evt-%u-%u",
             (hi_u32)(timestamp >> 16), (hi_u32)(timestamp & 0xFFFF));

    hi_u32 msg_len = snprintf(session_update, sizeof(session_update),
        "{\"type\":\"session.update\","
        "\"event_id\":\"%s\","
        "\"client_timestamp\":%u,"
        "\"session\":{"
        "\"model\":\"glm-realtime\","
        "\"modalities\":[\"audio\",\"text\"],"
        "\"voice\":\"tongtong\","
        "\"input_audio_format\":\"pcm16\","
        "\"output_audio_format\":\"pcm\","
        "\"input_audio_noise_reduction\":{\"type\":\"far_field\"},"
        "\"turn_detection\":{"
        "\"type\":\"server_vad\","
        "\"threshold\":0.3,"
        "\"prefix_padding_ms\":300,"
        "\"silence_duration_ms\":800,"
        "\"create_response\":true,"
        "\"interrupt_response\":true"
        "},"
        "\"temperature\":0.8,"
        "\"beta_fields\":{"
        "\"chat_mode\":\"audio\","
        "\"tts_source\":\"e2e\""
        "}"
        "}}",
        event_id_str, timestamp);

    if (msg_len >= sizeof(session_update)) {
        printf("[WS] msg too long\n");
        return;
    }

    if (!rws_socket_send_text(socket, session_update))
    {
        printf("[WS] send failed\n");
        return;
    }
    printf("[WS] session.update sent\n");

    // 等待一小段时间让 session.update 被处理,然后开始发送音频
    // 注意:这里不能直接调用 start_audio_sending_task()
    // 因为我们在网络回调上下文中,需要延迟到单独的任务中执行
}

/* 断开连接回调 */
static void on_disconnected(rws_socket socket)
{
    printf("[WS] disconnected\n");
    g_is_connected = HI_FALSE;
    g_audio_send_running = HI_FALSE;
    g_ws_socket = NULL;
}

/* 接收文本消息回调 */
static void on_received_text(rws_socket socket, const char *text, const unsigned int length)
{
    // 简单解析事件类型 - 减少printf避免栈溢出
    if (strstr(text, "\"type\":"))
    {
        const char *type_start = strstr(text, "\"type\":");
        if (type_start && strlen(type_start) > 7)
        {
            static char event_type[32] = {0};  // 使用静态缓冲区
            const char *p = type_start + 7;

            while (*p == ' ' || *p == ':') p++;

            if (*p == '"') {
                p++;
                hi_u32 i = 0;
                while (*p && *p != '"' && i < sizeof(event_type) - 1) {
                    event_type[i++] = *p++;
                }
                event_type[i] = '\0';
            }

            // 处理关键事件,减少日志输出
            if (strstr(event_type, "response.audio.delta"))
            {
                const char *fields[] = {"\"content\":", "\"audio\":", "\"delta\":"};
                const char *value_start = NULL;
                const char *value_end = NULL;
                hi_u32 len = 0;
                hi_bool found = HI_FALSE;

                for (hi_u32 i = 0; i < 3; i++)
                {
                    const char *field_start = strstr(text, fields[i]);
                    if (field_start)
                    {
                        value_start = strchr(field_start + strlen(fields[i]), '"');
                        if (value_start)
                        {
                            value_start++;
                            value_end = strchr(value_start, '"');
                            if (value_end)
                            {
                                len = value_end - value_start;
                                if (len > 0 && len < sizeof(g_audio_temp_buf))
                                {
                                    found = HI_TRUE;
                                    break;
                                }
                            }
                        }
                    }
                }

                if (found && value_start && value_end)
                {
                    memcpy(g_audio_temp_buf, value_start, len);
                    g_audio_temp_buf[len] = '\0';
                    process_audio_delta((const char *)g_audio_temp_buf);
                }
            }
            else if (strstr(event_type, "response.text.done"))
            {
                const char *text_start = strstr(text, "\"text\":");
                if (text_start)
                {
                    memset_s(g_response_text_buf, sizeof(g_response_text_buf), 0, sizeof(g_response_text_buf));
                    if (sscanf(text_start + 7, "\"%255[^\"]", g_response_text_buf) == 1) {
                        printf("[AI] %s\n", g_response_text_buf);
                    }
                }
            }
            else if (strstr(event_type, "session.updated"))
            {
                printf("[WS] session.updated\n");
            }
            else if (strstr(event_type, "error"))
            {
                printf("[ERR] %.*s\n", 100, text);
            }
        }
    }
}

/* 接收二进制消息回调 (暂不使用) */
static void on_received_bin(rws_socket socket, const void *data, const unsigned int length)
{
    printf("[WS] 收到二进制数据: %u bytes\n", length);
}

/* 发送文本消息 */
hi_void ws_send_text(const char *text)
{
    if (g_ws_socket && g_is_connected) {
        if (rws_socket_send_text(g_ws_socket, text)) {
            printf("[WS] 发送成功: %s\n", text);
        } else {
            printf("[WS] 发送失败\n");
        }
    } else {
        printf("[WS] 未连接,无法发送\n");
    }
}

/* 发送音频数据 (Base64编码) - 修正版本 */
hi_void ws_send_audio_base64(const char *base64_audio)
{
    if (!g_ws_socket || !g_is_connected || !base64_audio)
    {
        return;
    }

    // 使用静态缓冲区避免栈溢出
    static char msg[512];
    static hi_u32 audio_counter = 0;
    hi_u32 timestamp = hi_get_tick();

    hi_u32 msg_len = snprintf(msg, sizeof(msg),
                              "{\"type\":\"input_audio_buffer.append\",\"audio\":\"%s\",\"client_timestamp\":%u}",
                              base64_audio, timestamp);

    if (msg_len >= sizeof(msg) || !rws_socket_send_text(g_ws_socket, msg))
    {
        printf("[WS] send fail #%u\n", audio_counter);
    }
    else if (audio_counter % 100 == 0)
    {
        printf("[WS] sent #%u\n", audio_counter);
    }

    audio_counter++;
}

/* 提交音频缓冲区 - Server VAD 模式下不需要手动提交 */
hi_void ws_commit_audio(hi_void)
{
    if (!g_ws_socket || !g_is_connected)
    {
        printf("[WS] 未连接,无法提交音频\n");
        return;
    }

    printf("[WS] Server VAD 模式: 音频由服务器自动检测和提交,无需手动提交\n");
    printf("[WS] 继续发送音频数据即可,服务器会自动处理\n");
}

/* 初始化 WebSocket 连接 */
hi_void ws_client_init(hi_void)
{
    printf("\n");
    printf("========================================\n");
    printf("  WebSocket 客户端初始化\n");
    printf("  代理服务器: %s:%d\n", SERVER_HOST, SERVER_PORT);
    printf("========================================\n\n");

    // 创建 WebSocket
    g_ws_socket = rws_socket_create();
    if (!g_ws_socket) {
        printf("[WS] 创建 socket 失败\n");
        return;
    }

    // 设置连接参数
    rws_socket_set_url(g_ws_socket, "ws", SERVER_HOST, SERVER_PORT, SERVER_PATH);

    // 设置回调函数
    rws_socket_set_on_connected(g_ws_socket, on_connected);
    rws_socket_set_on_disconnected(g_ws_socket, on_disconnected);
    rws_socket_set_on_received_text(g_ws_socket, on_received_text);
    rws_socket_set_on_received_bin(g_ws_socket, on_received_bin);

    // 启动连接
    printf("[WS] 正在连接到 %s:%d...\n", SERVER_HOST, SERVER_PORT);
    if (!rws_socket_connect(g_ws_socket)) {
        printf("[WS] 连接失败\n");
        rws_error error = rws_socket_get_error(g_ws_socket);
        if (error) {
            printf("[WS] 错误代码: %d\n", rws_error_get_code(error));
            printf("[WS] 错误描述: %s\n", rws_error_get_description(error));
        }
    }

    printf("[WS] WebSocket 连接已启动\n");
}

/* 断开 WebSocket 连接 */
hi_void ws_client_disconnect(hi_void)
{
    printf("[WS] 断开连接...\n");

    // 停止音频发送任务
    g_audio_send_running = HI_FALSE;

    // 等待任务结束并删除
    if (g_audio_send_task != NULL)
    {
        hi_sleep(100);  // 给任务时间清理
        osThreadTerminate(g_audio_send_task);
        g_audio_send_task = NULL;
    }

    if (g_ws_socket) {
        rws_socket_disconnect_and_release(g_ws_socket);
        g_ws_socket = NULL;
    }

    g_is_connected = HI_FALSE;
    printf("[WS] 已断开连接\n");
}

/* 检查连接状态 */
hi_bool ws_is_connected(hi_void)
{
    return g_is_connected && g_ws_socket && rws_socket_is_connected(g_ws_socket);
}

/* 设置音频接收回调 */
hi_void ws_set_audio_callback(ws_on_audio_callback callback)
{
    g_audio_callback = callback;
}


/* 发送测试音频数据(模拟录音) */

ws_client_demo.h

/*
 * WebSocket Client Demo Header
 */

#ifndef __WS_CLIENT_DEMO_H__
#define __WS_CLIENT_DEMO_H__

#include "hi_types_base.h"

/*
 * 初始化 WebSocket 连接
 * 连接到配置的 PC 代理服务器
 */
hi_void ws_client_init(hi_void);

/*
 * 断开 WebSocket 连接
 */
hi_void ws_client_disconnect(hi_void);

/*
 * 发送文本消息
 * @param text: JSON格式的字符串
 */
hi_void ws_send_text(const char *text);

/*
 * 发送音频数据 (Base64编码)
 * @param base64_audio: base64编码的PCM音频数据
 */
hi_void ws_send_audio_base64(const char *base64_audio);

/*
 * 提交音频缓冲区
 * 触发AI处理已上传的音频
 */
hi_void ws_commit_audio(hi_void);

/*
 * 检查连接状态
 * @return: HI_TRUE if connected, HI_FALSE otherwise
 */
hi_bool ws_is_connected(hi_void);

/*
 * 音频数据接收回调
 * @param pcm_data: PCM音频数据
 * @param len: 数据长度
 */
typedef void (*ws_on_audio_callback)(const hi_u8 *pcm_data, hi_u32 len);

/*
 * 设置音频接收回调函数
 * @param callback: 接收到音频数据时的回调函数
 */
hi_void ws_set_audio_callback(ws_on_audio_callback callback);

/*
 * 启动音频发送任务
 * 开始发送测试音频数据到服务器
 */
hi_void ws_start_audio_sending(hi_void);

#endif /* __WS_CLIENT_DEMO_H__ */

i2s_microphone.c

#include "i2s_microphone.h"
#include <hi_i2s.h>
#include <hi_io.h>
#include <hi_types_base.h>
#include <string.h>
#include <stdio.h>

/**
 * @brief I2S 麦克风初始化
 */
hi_void i2s_microphone_init(hi_void)
{
    hi_u32 ret;

    printf("[MIC] I2S 麦克风初始化...\n");

    // 1. 配置 I2S GPIO - 使用正确的 GPIO 引脚
    // 根据 Hi3861 硬件手册,I2S RX 引脚为:
    // GPIO_7: I2S0_BCLK
    // GPIO_8: I2S0_WS
    // GPIO_11: I2S0_RX
    hi_io_set_func(HI_IO_NAME_GPIO_7, HI_IO_FUNC_GPIO_7_I2S0_BCLK);  // BCLK
    hi_io_set_func(HI_IO_NAME_GPIO_8, HI_IO_FUNC_GPIO_8_I2S0_WS);    // WS
    hi_io_set_func(HI_IO_NAME_GPIO_11, HI_IO_FUNC_GPIO_11_I2S0_RX);  // RX DATA

    // 2. 配置 I2S
    hi_i2s_attribute i2s_cfg = {0};
    i2s_cfg.sample_rate = HI_I2S_SAMPLE_RATE_16K;  // 16kHz 采样率
    i2s_cfg.resolution = HI_I2S_RESOLUTION_16BIT;  // 16-bit 位深

    ret = hi_i2s_init(&i2s_cfg);
    if (ret != HI_ERR_SUCCESS) {
        printf("[MIC] I2S 初始化失败: 0x%x\n", ret);
        return;
    }

    printf("[MIC] I2S 初始化成功\n");
    printf("[MIC] 采样率: %u Hz\n", I2S_MIC_SAMPLE_RATE);
    printf("[MIC] 声道数: %u\n", I2S_MIC_CHANNELS);
    printf("[MIC] 位深: %u bits\n", I2S_MIC_BITS_PER_SAMPLE);

    printf("[MIC] 麦克风初始化完成\n");
}

/**
 * @brief I2S 麦克风启动
 */
hi_void i2s_microphone_start(hi_void)
{
    printf("[MIC] I2S 麦克风已启动(自动读取模式)\n");
}

/**
 * @brief I2S 麦克风停止
 */
hi_void i2s_microphone_stop(hi_void)
{
    printf("[MIC] I2S 麦克风已停止\n");
}

/**
 * @brief I2S 麦克风读取一帧数据
 * @param buffer 接收缓冲区
 * @param size 读取大小
 * @return 实际读取的字节数
 */
hi_u32 i2s_microphone_read_frame(hi_u8 *buffer, hi_u32 size)
{
    hi_u32 ret;

    // 使用 hi_i2s_read 读取音频数据
    // 超时时间设置为 20ms(一帧的时长)
    ret = hi_i2s_read(buffer, size, 20);

    if (ret != HI_ERR_SUCCESS) {
        // 超时或错误不算失败,返回0表示暂无数据
        return 0;
    }

    return size;  // 返回请求的大小
}

i2s_microphone.h

#ifndef I2S_MICROPHONE_H
#define I2S_MICROPHONE_H

#ifdef __cplusplus
extern "C" {
#endif

#include <hi_types_base.h>

// I2S 麦克风采集配置
#define I2S_MIC_SAMPLE_RATE        16000       // 采样率 16kHz
#define I2S_MIC_CHANNELS          1            // 单声道
#define I2S_MIC_BITS_PER_SAMPLE   16           // 16-bit
#define I2S_MIC_FRAME_SIZE        320          // 每帧 320 字节 = 20ms @ 16kHz
#define I2S_MIC_BUFFER_SIZE       640          // 缓冲区大小(2 帧)

// I2S 麦克风采集初始化
hi_void i2s_microphone_init(hi_void);

// I2S 麦克风采集一帧数据
hi_u32 i2s_microphone_read_frame(hi_u8 *buffer, hi_u32 size);

// I2S 麦克风启动
hi_void i2s_microphone_start(hi_void);

// I2S 麦克风停止
hi_void i2s_microphone_stop(hi_void);

#ifdef __cplusplus
}
#endif

#endif // I2S_MICROPHONE_H

proxy_server.py

#!/usr/bin/env python3
"""
WebSocket to WSS Proxy Server
本地 WS 服务器转发到远程 WSS 服务器
"""

import asyncio
import websockets
import websockets.server
import logging
from datetime import datetime
import json
import signal
import sys
import time
import jwt
import base64

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler('proxy_server.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def generate_jwt_token(api_key, expire_seconds=600):
    """生成JWT Token用于鉴权"""
    api_key_parts = api_key.split('.')
    api_key_id = api_key_parts[0]
    api_secret = api_key_parts[1] if len(api_key_parts) > 1 else ""

    payload = {
        "api_key": api_key_id,
        "exp": int(time.time()) + expire_seconds,
        "timestamp": int(time.time() * 1000)
    }
    token = jwt.encode(payload, api_secret, algorithm="HS256", headers={"alg": "HS256", "sign_type": "SIGN"})
    return token


class WSProxyConfig:
    """代理服务器配置"""

    def __init__(self, config_file='config.json'):
        self.local_host = "0.0.0.0"
        self.local_port = 8080
        self.remote_wss_url = "wss://open.bigmodel.cn/api/paas/v4/realtime"
        self.api_key = "8dbddcd50b44fea252fc99fa881f9c3e.2m10cAx1oOdz26h7"  # 智谱AI的API Key
        self.reconnect_interval = 5
        self.ping_interval = 20
        self.ping_timeout = 10

        # 尝试从配置文件加载
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)
                self.__dict__.update(config)
                logger.info(f"已加载配置文件: {config_file}")
        except FileNotFoundError:
            logger.warning(f"配置文件不存在,使用默认配置: {config_file}")
        except Exception as e:
            logger.error(f"加载配置文件失败: {e}")


class WSProxyClient:
    """单个客户端代理连接"""

    def __init__(self, client_ws, config):
        self.client_ws = client_ws
        self.config = config
        self.server_ws = None
        self.client_addr = f"{client_ws.remote_address[0]}:{client_ws.remote_address[1]}"
        self.running = True

    async def connect_to_remote(self):
        """连接到远程 WSS 服务器"""
        try:
            logger.info(f"[{self.client_addr}] 正在连接远程服务器: {self.config.remote_wss_url}")

            # 构建JWT Token用于鉴权
            headers = {}
            if self.config.api_key:
                jwt_token = generate_jwt_token(self.config.api_key)
                headers['Authorization'] = f'Bearer {jwt_token}'
                logger.info(f"[{self.client_addr}] 使用JWT Token鉴权")
                logger.debug(f"[{self.client_addr}] JWT Token: {jwt_token[:50]}...")
            else:
                logger.warning(f"[{self.client_addr}] 未配置API Key,可能导致鉴权失败")

            # 使用 websockets.connect 并传递 additional_headers
            # 禁用ping/pong,因为GLM-Realtime服务器可能不期望客户端发送heartbeat
            self.server_ws = await websockets.connect(
                self.config.remote_wss_url,
                ping_interval=None,  # 禁用ping
                ping_timeout=None,   # 禁用ping timeout
                additional_headers=headers
            )
            logger.info(f"[{self.client_addr}] 成功连接到远程服务器")
            return True
        except Exception as e:
            logger.error(f"[{self.client_addr}] 连接远程服务器失败: {e}")
            return False

    async def forward_client_to_server(self):
        """客户端消息 -> 远程服务器"""
        logger.info(f"[{self.client_addr}] [客户端->服务器] 转发任务开始")
        try:
            async for message in self.client_ws:
                if not self.running:
                    logger.warning(f"[{self.client_addr}] [客户端->服务器] running=False,退出循环")
                    break
                if self.server_ws:
                    try:
                        # 打印消息内容用于调试
                        try:
                            msg_json = json.loads(message)
                            logger.info(f"[{self.client_addr}] 客户端 -> 服务器: {len(message)} bytes, type={msg_json.get('type', 'unknown')}")
                            # 如果是 session.update,打印完整内容
                            if msg_json.get('type') == 'session.update':
                                logger.info(f"[{self.client_addr}] session.update 内容: {json.dumps(msg_json, ensure_ascii=False, indent=2)}")
                        except:
                            # 记录非JSON数据的前100个字符,用于调试
                            truncated = message[:100] if len(message) > 100 else message
                            logger.info(f"[{self.client_addr}] 客户端 -> 服务器: {len(message)} bytes (非JSON): {truncated}")

                        await self.server_ws.send(message)
                    except Exception:
                        logger.warning(f"[{self.client_addr}] 远程连接未建立,消息丢弃")
                else:
                    logger.warning(f"[{self.client_addr}] 远程连接未建立,消息丢弃")
        except websockets.exceptions.ConnectionClosed as e:
            logger.info(f"[{self.client_addr}] [客户端->服务器] 客户端断开连接: code={e.code}, reason={e.reason}")
        except Exception as e:
            logger.error(f"[{self.client_addr}] [客户端->服务器] 转发客户端消息错误: {type(e).__name__}: {e}")
        finally:
            logger.info(f"[{self.client_addr}] [客户端->服务器] 转发任务结束")
            self.running = False

    async def forward_server_to_client(self):
        """远程服务器消息 -> 客户端"""
        logger.info(f"[{self.client_addr}] [服务器->客户端] 转发任务开始")
        message_count = 0
        skipped_audio_count = 0  # 跳过的音频消息计数

        try:
            async for message in self.server_ws:
                if not self.running:
                    logger.warning(f"[{self.client_addr}] [服务器->客户端] running=False,退出循环")
                    break

                message_count += 1

                # 检查消息大小,拦截超大音频包
                msg_size = len(message)

                # 如果消息超过 10KB,解析类型判断是否为音频包
                if msg_size > 10240:
                    try:
                        msg_json = json.loads(message)
                        msg_type = msg_json.get('type', 'unknown')

                        # 拦截大音频包,避免 Hi3861 内存不足
                        if msg_type == 'response.audio.delta':
                            skipped_audio_count += 1
                            logger.info(f"[{self.client_addr}] [服务器->客户端] 跳过大音频包: {msg_size} bytes, type={msg_type} (累计跳过: {skipped_audio_count})")
                            continue  # 不转发此消息
                        elif msg_type == 'response.audio.done':
                            logger.info(f"[{self.client_addr}] [服务器->客户端] 音频完成: {msg_size} bytes, type={msg_type}")
                        elif msg_type == 'response.text.delta':
                            # 文本消息通常很小,可以正常转发
                            logger.info(f"[{self.client_addr}] [服务器->客户端] 文本回复: {msg_size} bytes, type={msg_type}")
                        elif msg_type == 'response.text.done':
                            logger.info(f"[{self.client_addr}] [服务器->客户端] 文本完成: {msg_size} bytes, type={msg_type}")
                        else:
                            logger.info(f"[{self.client_addr}] [服务器->客户端] 大消息: {msg_size} bytes, type={msg_type}")
                    except Exception as e:
                        logger.warning(f"[{self.client_addr}] 解析大消息失败: {e}")

                # 打印消息内容用于调试
                try:
                    msg_json = json.loads(message)
                    msg_type = msg_json.get('type', 'unknown')

                    # 特殊处理错误消息
                    if msg_type == 'error':
                        logger.error(f"[{self.client_addr}] 服务器返回错误: {json.dumps(msg_json, ensure_ascii=False, indent=2)}")
                    elif msg_type in ['session.updated', 'session.created']:
                        logger.info(f"[{self.client_addr}] 服务器 -> 客户端: {len(message)} bytes, type={msg_type}")
                    elif msg_type in ['input_audio_buffer.speech_started', 'input_audio_buffer.speech_stopped',
                                      'input_audio_buffer.committed', 'response.created']:
                        logger.info(f"[{self.client_addr}] 服务器 -> 客户端: {len(message)} bytes, type={msg_type}")
                    elif msg_type == 'heartbeat':
                        logger.info(f"[{self.client_addr}] 服务器 -> 客户端: {len(message)} bytes, type={msg_type} (已处理 {message_count} 条消息)")
                    elif msg_type == 'response.text.delta':
                        # 提取文本内容
                        text_content = msg_json.get('text', '')
                        if text_content:
                            logger.info(f"[{self.client_addr}] [AI回复] {text_content[:100]}...")
                    elif msg_type == 'response.text.done':
                        logger.info(f"[{self.client_addr}] [服务器->客户端] {len(message)} bytes, type={msg_type}")
                    else:
                        # 其他小消息正常转发
                        if msg_size < 4096:  # 小于 4KB 的消息才转发
                            logger.info(f"[{self.client_addr}] 服务器 -> 客户端: {len(message)} bytes, type={msg_type}")
                except Exception as e:
                    logger.warning(f"[{self.client_addr}] 服务器 -> 客户端: {len(message)} bytes (非JSON), 解析错误: {e}")

                # 转发消息到客户端(如果是小消息)
                if msg_size < 4096:  # 只转发小于 4KB 的消息
                    await self.client_ws.send(message)
                else:
                    logger.info(f"[{self.client_addr}] [服务器->客户端] 跳过大消息: {msg_size} bytes")

        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{self.client_addr}] [服务器->客户端] 远程服务器断开连接: code={e.code}, reason={e.reason}")
        except Exception as e:
            logger.error(f"[{self.client_addr}] [服务器->客户端] 转发服务器消息错误: {type(e).__name__}: {e}")
        finally:
            logger.info(f"[{self.client_addr}] [服务器->客户端] 转发任务结束,共处理 {message_count} 条消息,跳过 {skipped_audio_count} 个大音频包")
            self.running = False

    async def handle(self):
        """处理双向转发"""
        if not await self.connect_to_remote():
            await self.client_ws.close()
            return

        logger.info(f"[{self.client_addr}] 开始双向转发...")

        try:
            # 并行运行两个方向的转发
            results = await asyncio.gather(
                self.forward_client_to_server(),
                self.forward_server_to_client(),
                return_exceptions=True
            )

            # 检查是否有异常
            for i, result in enumerate(results):
                task_name = "客户端->服务器" if i == 0 else "服务器->客户端"
                if isinstance(result, Exception):
                    logger.warning(f"[{self.client_addr}] {task_name} 任务异常: {type(result).__name__}: {result}")
                else:
                    logger.info(f"[{self.client_addr}] {task_name} 任务正常结束")

            logger.info(f"[{self.client_addr}] 双向转发任务结束")
        except Exception as e:
            logger.error(f"[{self.client_addr}] 双向转发异常: {type(e).__name__}: {e}")
            import traceback
            logger.error(f"[{self.client_addr}] 异常堆栈:\n{''.join(traceback.format_tb(e.__traceback__))}")
        finally:
            await self.cleanup()

    async def cleanup(self):
        """清理资源"""
        self.running = False
        logger.info(f"[{self.client_addr}] 开始清理资源...")

        # 检查客户端连接状态
        try:
            if self.client_ws:
                logger.info(f"[{self.client_addr}] 关闭客户端连接...")
                await self.client_ws.close()
                logger.info(f"[{self.client_addr}] 客户端连接已关闭")
        except Exception as e:
            logger.error(f"[{self.client_addr}] 关闭客户端连接错误: {e}")

        # 检查服务器连接状态
        try:
            if self.server_ws:
                logger.info(f"[{self.client_addr}] 关闭服务器连接...")
                await self.server_ws.close()
                logger.info(f"[{self.client_addr}] 服务器连接已关闭")
        except Exception as e:
            logger.error(f"[{self.client_addr}] 关闭服务器连接错误: {e}")

        logger.info(f"[{self.client_addr}] 连接已关闭")


class WSProxyServer:
    """WebSocket 代理服务器"""

    def __init__(self, config):
        self.config = config
        self.server = None
        self.clients = []
        self.shutdown_event = asyncio.Event()

    async def handle_client(self, client_ws):
        """处理新的客户端连接"""
        client_addr = f"{client_ws.remote_address[0]}:{client_ws.remote_address[1]}"
        logger.info(f"新客户端连接: {client_addr}")

        proxy_client = WSProxyClient(client_ws, self.config)
        self.clients.append(proxy_client)

        try:
            await proxy_client.handle()
        finally:
            if proxy_client in self.clients:
                self.clients.remove(proxy_client)

    async def start(self):
        """启动代理服务器"""
        logger.info(f"启动代理服务器: ws://{self.config.local_host}:{self.config.local_port}")
        logger.info(f"转发目标: {self.config.remote_wss_url}")

        self.server = await websockets.server.serve(
            self.handle_client,
            self.config.local_host,
            self.config.local_port,
            ping_interval=self.config.ping_interval,
            ping_timeout=self.config.ping_timeout
        )

        logger.info("代理服务器已启动,等待连接...")
        await self.shutdown_event.wait()

    async def stop(self):
        """停止代理服务器"""
        logger.info("正在停止代理服务器...")
        self.shutdown_event.set()

        if self.server:
            self.server.close()
            await self.server.wait_closed()

        # 关闭所有客户端连接
        for client in self.clients:
            await client.cleanup()

        logger.info("代理服务器已停止")


def signal_handler(signum, frame):
    """信号处理器"""
    logger.info(f"收到信号 {signum},准备退出...")
    sys.exit(0)


async def main():
    """主函数"""
    # 注册信号处理
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # 加载配置
    config = WSProxyConfig()

    # 创建并启动服务器
    server = WSProxyServer(config)

    try:
        await server.start()
    except KeyboardInterrupt:
        logger.info("收到中断信号")
    finally:
        await server.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("程序已退出")

问答