#include "sock_msg.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <fcntl.h>
#include <errno.h>
#include <unistd.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/epoll.h>
#include <sys/un.h>
#include <sys/types.h>
#include <sys/msg.h>
#include <mqueue.h>
#include <sys/ipc.h>

Message            messageQueue [ MAX_MESSAGES ];    // 消息队列
int                messageCount = 0;                 // 消息队列中消息数量
int                msgid        = -1;                // 全局变量，用于存储消息队列ID

// socket信息结构体
typedef struct
{
    int fd;
    sock_callback_t callback;
    uint32_t mode;
} socket_info_t;

// 全局数组来保存socket信息
socket_info_t sockets[MAX_CLIENTS + 1] = {0};

// 消息通道路径数组
const char *msg_channel_paths [ MSG_CHAN_MAX ] = {
    "/tmp/msg_channel_app",
    "/tmp/msg_channel_lim",
    "/tmp/msg_channel_iomod",
};

int sock_init(int server, uint32_t ipaddr, uint32_t port)
{
    int                sockfd;
    struct sockaddr_in serv_addr;

    // 创建socket
    sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if ( sockfd < 0 )
    {
        // 错误处理
        perror("socket error");
        return -1;
    }

    // 如果是服务器模式
    if ( server )
    {
        // 初始化服务器地址结构，绑定IP地址和端口
        bzero(( char * )&serv_addr, sizeof(serv_addr));    // 清零结构体
        serv_addr.sin_family      = AF_INET;               // 设置服务器地址结构
        serv_addr.sin_addr.s_addr = ipaddr;                // 设置IP地址
        serv_addr.sin_port        = htons(port);           // 设置端口

        if ( bind(sockfd, ( struct sockaddr * )&serv_addr, sizeof(serv_addr)) < 0 )
        {
            // 错误处理
            perror("bind error");
            return -1;
        }

        // 监听
        if ( listen(sockfd, 20) < 0 )
        {
            // 错误处理
            perror("listen error");
            return -1;
        }
    }
    // 客户端模式
    else
    {
        // 连接
        bzero(( char * )&serv_addr, sizeof(serv_addr));    // 清零结构体
        serv_addr.sin_family      = AF_INET;
        serv_addr.sin_addr.s_addr = ipaddr;         // 设置IP地址
        serv_addr.sin_port        = htons(port);    // 设置端口

        if ( connect(sockfd, ( struct sockaddr * )&serv_addr, sizeof(serv_addr)) < 0 )
        {
            // 错误处理
            perror("connect error");
            return -1;
        }
    }

    return sockfd;
}

void sock_exit(int fd)
{
    close(fd);
}

int sock_write(int fd, const uint8_t *buf, int len)
{
    int bytes_sent = 0;    // 记录已发送的字节数
    int n;                 // 记录每次发送的字节数

    while ( bytes_sent < len )
    {
        // 使用write函数发送数据，返回值是实际写入的字节数
        n = write(fd, buf + bytes_sent, len - bytes_sent);
        if ( n < 0 )
        {
            // 如果write函数返回-1，表示发送失败
            if ( errno == EINTR )
            {
                // 如果是因为被信号中断，则继续发送剩余的数据
                perror("Interrupted system call, write again");
                continue;
            }
            else
            {
                // 其他错误，返回-1表示失败
                perror("Other errors, fail to send");
                return -1;
            }
        }
        else if ( n == 0 )
        {
            // 如果write函数返回0，表示对端已经关闭连接，发送失败
            perror("Peer closed connection");
            return -1;
        }
        else
        {
            // 更新已发送的字节数
            bytes_sent += n;
        }
    }

    return bytes_sent;
}

int sock_read(int fd, uint8_t *buf, int len)
{
    int bytes_received = 0;    // 记录已接收的字节数
    int n;

    while ( bytes_received < len )
    {
        // 使用read函数接收数据，返回值是实际读入的字节数
        n = read(fd, buf + bytes_received, len - bytes_received);
        if ( n < 0 )
        {
            // 如果read函数返回-1，表示接收失败
            if ( errno == EINTR )
            {
                // 如果是因为被信号中断，则继续接收剩余的数据
                perror("Interrupted system call, read again");
                continue;
            }
            else
            {
                // 其他错误，返回-1表示失败
                perror("Other errors, fail to receive");
                return -1;
            }
        }
        else if ( n == 0 )
        {
            // 如果read函数返回0，表示对端已经关闭连接，接收完成
            printf("Peer closed connection\n");
            break;
        }
        else
        {
            // 更新已接收的字节数
            bytes_received += n;
        }
    }

    return bytes_received;
}

int sock_info(int fd, uint32_t *ipaddr, uint32_t *port)
{
    struct sockaddr_in peer_addr;                            // 对端地址结构体
    socklen_t          peer_addr_len = sizeof(peer_addr);    // 对端地址结构体长度

    // 使用getpeername函数获取对端地址信息
    if ( getpeername(fd, ( struct sockaddr * )&peer_addr, &peer_addr_len) == -1 )
    {
        // 如果getpeername函数返回-1，表示获取失败
        perror("getpeername error");
        return -1;
    }

    // 从对端地址结构体中获取IP地址和端口号
    *ipaddr = ntohl(peer_addr.sin_addr.s_addr);
    *port   = ntohs(peer_addr.sin_port);

    return 0;
}

int sock_ev_register(int fd, sock_callback_t cb, uint32_t mode)
{
    for (int i = 0; i < MAX_CLIENTS + 1; i++)
    {
        if (sockets[i].fd == 0)
        {
            sockets[i].fd = fd;
            sockets[i].callback = cb;
            sockets[i].mode = mode;
            return 0; // 成功
        }
    }
    return -1; // 失败，没有空位
}

int sock_ev_unregister(int fd)
{
    for (int i = 0; i < MAX_CLIENTS + 1; i++)
    {
        if (sockets[i].fd == fd)
        {
            sockets[i].fd = 0;
            sockets[i].callback = NULL;
            sockets[i].mode = 0;
            return 0; // 成功
        }
    }
    return -1; // 失败，未找到对应的socket
}

int init_msg_queue( )
{
    key_t key;
    key = ftok(".", 10);
    if ( key == -1 )
    {
        perror("ftok(): ");
        return -1;
    }

    msgid = msgget(key, IPC_CREAT | 0666);
    if ( msgid == -1 )
    {
        perror("msgget(): ");
        return -1;
    }

    return 0;
}

int destroy_msg_queue( )
{
    if ( msgid != -1 )
    {
        if ( msgctl(msgid, IPC_RMID, NULL) == -1 )
        {
            perror("msgctl(): ");
            return -1;
        }
        msgid = -1;
    }
    return 0;
}

int msg_local_recv(msg_data_t *msg)
{
    if ( msgid == -1 )
    {
        fprintf(stderr, "Message queue not initialized.\n");
        return -1;
    }

    printf("msgid: %d\n", msgid);
    ssize_t rbytes = msgrcv(msgid, ( void * )msg, sizeof(msg_data_t), 0, 0);
    if ( rbytes == -1 )
    {
        perror("[ERROR] msgrcv(): ");
        return -1;
    }

    return ( int )rbytes;
}

int msg_local_send(uint16_t type, uint32_t code, void *pad, int len)
{
    if ( msgid == -1 )
    {
        fprintf(stderr, "Message queue not initialized.\n");
        return -1;
    }

    msg_data_t msg;
    msg.type = type;
    msg.code = code;
    memcpy(msg.pad, pad, len);
    msg.len = len;

    if ( msgsnd(msgid, ( const void * )&msg, sizeof(msg), 0) == -1 )
    {
        perror("msgsnd(): ");
        return -1;
    }

    return 0;
}

int msg_local_recv_pulse(uint16_t *type, uint32_t *code)
{
    if ( messageCount > 0 )
    {
        *type = messageQueue [ 0 ].type;
        *code = messageQueue [ 0 ].code;
        // 移动队列
        for ( int i = 0; i < messageCount - 1; i++ )
        {
            messageQueue [ i ] = messageQueue [ i + 1 ];
        }
        messageCount--;
        return 0;    // 成功
    }
    else
    {
        errno = EAGAIN;    // 队列空
        return -1;         // 失败
    }
}

int msg_local_send_pulse(uint16_t type, uint32_t code)
{
    if ( messageCount < MAX_MESSAGES )
    {
        messageQueue [ messageCount ].type = type;
        messageQueue [ messageCount ].code = code;
        messageCount++;
        return 0;    // 成功
    }
    else
    {
        errno = EAGAIN;    // 队列满
        return -1;         // 失败
    }
}

int msg_recv(int channel, msg_data_t *msg)
{
    if ( channel < 0 || channel >= MSG_CHAN_MAX || msg == NULL )
    {
        return -1;
    }

    int fd = open(msg_channel_paths [ channel ], O_RDONLY);
    if ( fd == -1 )
    {
        perror("msg_recv: open");
        return -1;
    }

    int bytes_read = read(fd, msg, sizeof(msg_data_t));
    close(fd);

    return bytes_read;
}

int msg_send(int channel, uint16_t type, uint32_t code, void *pad, int len)
{
    if ( channel < 0 || channel >= MSG_CHAN_MAX || (len > 0 && pad == NULL) )
    {
        return -1;
    }

    msg_data_t msg = {
        .type = type,
        .code = code,
        .len  = len};
    memcpy((void *)(msg.pad), pad, len);

    int fd = open(msg_channel_paths [ channel ], O_WRONLY);
    if ( fd == -1 )
    {
        perror("msg_send: open");
        return -1;
    }

    int bytes_written = write(fd, &msg, sizeof(msg_data_t));
    close(fd);

    return bytes_written == sizeof(msg_data_t) ? 0 : -1;
}