/* * KTCPVS An implementation of the TCP Virtual Server daemon inside * kernel for the LINUX operating system. KTCPVS can be used * to build a moderately scalable and highly available server * based on a cluster of servers, with more flexibility. * * Version: $Id$ * * Authors: Wensong Zhang * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU General Public License * as published by the Free Software Foundation; either version * 2 of the License, or (at your option) any later version. * */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include "tcp_vs.h" int StartListening(struct tcp_vs *vs) { struct socket *sock; struct sockaddr_in sin; int error; EnterFunction("StartListening"); /* First create a socket */ error = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock); if (error<0) TCP_VS_ERR("Error during creation of socket; terminating\n"); /* Now bind the socket */ sin.sin_family = AF_INET; sin.sin_addr.s_addr = INADDR_ANY; sin.sin_port = htons((unsigned short)vs->serverport); error = sock->ops->bind(sock, (struct sockaddr*)&sin, sizeof(sin)); if (error<0) { TCP_VS_ERR("Error binding socket. This means that some other " "daemon is (or was a short time ago) using port %i.\n", vs->serverport); return 0; } /* Grrr... setsockopt() does this. */ sock->sk->reuse = 1; /* Now, start listening on the socket */ /* I have no idea what a sane backlog-value is. 48 works so far. */ error=sock->ops->listen(sock, 48); if (error!=0) (void)printk(KERN_ERR "ktcpvs: Error listening on socket \n"); vs->mainsock = sock; LeaveFunction("StartListening"); return 1; } void StopListening(struct tcp_vs *vs) { struct socket *sock; EnterFunction("StopListening"); if (vs->mainsock == NULL) return; sock=vs->mainsock; vs->mainsock = NULL; sock_release(sock); LeaveFunction("StopListening"); } struct socket * tcp_vs_connect2dest(struct tcp_vs_dest *dest) { struct socket *sock; struct sockaddr_in sin; int error; EnterFunction("tcp_vs_connect2dest"); /* First create a socket */ error = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock); if (error<0) TCP_VS_ERR("Error during creation of socket; terminating\n"); /* Now connect to the destination server */ sin.sin_family = AF_INET; sin.sin_addr.s_addr = dest->addr; sin.sin_port = dest->port; error = sock->ops->connect(sock, (struct sockaddr*)&sin, sizeof(sin), 0); if (error<0) { TCP_VS_ERR("Error connecting to the remote host\n"); return NULL; } /* Grrr... setsockopt() does this. */ sock->sk->reuse = 1; LeaveFunction("tcp_vs_connect2dest"); return sock; } /* * tcp_vs_sendbuffer and tcp_vs_sendbuffer_async are to send bytes from * the buffer to the socket. * A positive return-value indicates the number of bytes sent, a negative * value indicates an error-condition. */ int tcp_vs_sendbuffer(struct socket *sock, const char *buffer,const size_t length) { struct msghdr msg; mm_segment_t oldfs; struct iovec iov; int len; EnterFunction("tcp_vs_sendbuffer"); msg.msg_name = 0; msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = MSG_NOSIGNAL; msg.msg_iov->iov_len = (__kernel_size_t)length; msg.msg_iov->iov_base = (char*) buffer; oldfs = get_fs(); set_fs(KERNEL_DS); len = sock_sendmsg(sock, &msg, length); set_fs(oldfs); LeaveFunction("tcp_vs_sendbuffer"); return len; } int tcp_vs_sendbuffer_async(struct socket *sock, const char *buffer,const size_t length) { struct msghdr msg; mm_segment_t oldfs; struct iovec iov; int len; EnterFunction("tcp_vs_sendbuffer_async"); msg.msg_name = 0; msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = MSG_DONTWAIT|MSG_NOSIGNAL; msg.msg_iov->iov_base = (void *)buffer; msg.msg_iov->iov_len = length; oldfs = get_fs(); set_fs(KERNEL_DS); len = sock_sendmsg(sock, &msg, (size_t)(length)); set_fs(oldfs); LeaveFunction("tcp_vs_sendbuffer_async"); return len; } int tcp_vs_recvbuffer(struct socket *sock, char *buffer, const size_t buflen) { struct msghdr msg; struct iovec iov; int len; mm_segment_t oldfs; EnterFunction("tcp_vs_recvbuffer"); /* Receive a packet */ msg.msg_name = 0; msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; iov.iov_base = buffer; iov.iov_len = (size_t)buflen-1; oldfs = get_fs(); set_fs(KERNEL_DS); len = sock_recvmsg(sock, &msg, buflen-1, 0); set_fs(oldfs); if (len < 0) return -1; #ifdef CONFIG_TCP_VS_DEBUG if (len > 200) { char str[201]; strncpy(str, buffer, 200); str[200] = 0; TCP_VS_DBG("len: %d str: %s\n", len, str); } else { buffer[len]=0; TCP_VS_DBG("len: %d str: %s\n", len, buffer); } #endif LeaveFunction("tcp_vs_recvbuffer"); return len; } int tcp_vs_recvbuffer_async(struct socket *sock, char *buffer, const size_t buflen) { struct msghdr msg; struct iovec iov; int len; mm_segment_t oldfs; EnterFunction("tcp_vs_recvbuffer_async"); /* Receive a packet */ msg.msg_name = 0; msg.msg_namelen = 0; msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; iov.iov_base = buffer; iov.iov_len = (size_t)buflen-1; oldfs = get_fs(); set_fs(KERNEL_DS); len = sock_recvmsg(sock, &msg, buflen-1, MSG_DONTWAIT); /* MSG_PEEK */ set_fs(oldfs); if (len < 0) return -1; #ifdef CONFIG_TCP_VS_DEBUG if (len > 200) { char str[201]; strncpy(str, buffer, 200); str[200] = 0; TCP_VS_DBG("len: %d str: %s\n", len, str); } else { buffer[len]=0; TCP_VS_DBG("len: %d str: %s\n", len, buffer); } #endif LeaveFunction("tcp_vs_recvbuffer_async"); return len; } char *tcp_vs_getline(char *s, char *token, int n) { int i; token[0] = '\0'; if (s == NULL) return NULL; if (*s == '\0') return NULL; while (*s == '\n') s++; i = 0; while (*s != '\0' && *s != '\n') { if (i < n-1) { token[i] = *s; i++; } s++; } token[i] = '\0'; return s; } char *tcp_vs_getword(char *s, char *token, int n) { int i; token[0] = '\0'; if (s == NULL) return NULL; while (isspace(*s)) s++; if (*s == '\0' || *s == '\n') return NULL; i = 0; while (*s != '\0' && *s != '\n' && !isspace(*s)) { if (i < n-1) { token[i] = *s; i++; } s++; } token[i] = '\0'; return s; }