/* IMSpector - Instant Messenger Transparent Proxy Service
 * http://www.imspector.org/
 * (c) Lawrence Manning <lawrence@aslak.net>, 2006
 * 
 * Contributions from:
 *     Ryan Wagoner <ryan@wgnrs.dynu.com>, 2006
 *     Duane Wessels, Squid 2.6-STABLE5, src/client_side.c, clientNatLookup()
 *     Simon Brassington <simon@the-brassingtons.co.uk>, 2007
 *
 * Released under the GPL v2. */

#include "imspector.h"

/* Length of the socket address structure for an address family: AF_INET
 * uses struct sockaddr_in, every other domain is treated as AF_UNIX and
 * uses struct sockaddr_un. */
#define SOCK_SIZE(domain) \
	((domain) != AF_INET ? sizeof(struct sockaddr_un) : sizeof(struct sockaddr_in))

/* Creates an unopened socket wrapper for address family domainin and socket
 * type typein. No descriptor (and, with SSL support, no SSL object or peer
 * certificate) exists until listensocket() or connectsocket() is called. */
Socket::Socket(int domainin, int typein)
	: domain(domainin), type(typein), fd(-1)
#ifdef HAVE_SSL
	, ssl(NULL), peercert(NULL)
#endif
{
}

/* Releases everything the socket still owns by delegating to closesocket().
 * With SSL support this frees the SSL object and the peer certificate as
 * well as closing the descriptor; previously only the descriptor was closed,
 * so both SSL resources leaked whenever closesocket() had not been called.
 * closesocket() resets every handle, so an earlier explicit call is safe. */
Socket::~Socket()
{
	closesocket();
}

/* Creates a listening socket bound to localaddress ("host:port" for AF_INET,
 * a filesystem path otherwise) with a backlog of 5.
 *
 * Returns true on success. On failure the problem is logged and false is
 * returned; if the descriptor had already been created it is closed AND
 * reset to -1, so the destructor / closesocket() no longer close it a second
 * time (which could hit an unrelated descriptor that reused the number). */
bool Socket::listensocket(std::string localaddress)
{
	if ((fd = socket(domain, type, 0)) < 0)
	{
		syslog(LOG_ERR, "Listen socket, socket() failed");
		return false;
	}
	
	struct mysockaddr localname = stringtosockaddr(localaddress);
	
	if (domain == AF_INET)
	{
		/* SO_REUSEADDR lets the proxy rebind its port right after a restart. */
		int i = 1;
		setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i));
	}
	else
	{
		/* Remove any existing file at the socket path so bind() can create it. */
		unlink(localaddress.c_str());
	}

	if (bind(fd, (struct sockaddr *) &localname, SOCK_SIZE(domain)) < 0)
	{
		syslog(LOG_ERR, "Listen socket, bind() failed");
		close(fd);
		fd = -1;
		return false;
	}
	
	if (listen(fd, 5) < 0)
	{
		syslog(LOG_ERR, "Listen socket, listen() failed");
		close(fd);
		fd = -1;
		return false;
	}
	
	return true;
}

bool Socket::awaitconnection(class Socket &clientsocket, std::string &clientaddress)
{
	int newfd;
	struct sockaddr_in clientsockaddr;
	socklen_t clientsockaddrlen = sizeof(struct sockaddr_in);

	if ((newfd = accept(fd, (struct sockaddr *) &clientsockaddr, &clientsockaddrlen)) < 0)
		return false;
	
	clientsocket.setfd(newfd);
	
	clientaddress = sockaddrtostring((struct mysockaddr *) &clientsockaddr);
	
	return true;
}

/* Returns the original destination ("ip:port") of a connection that was
 * transparently redirected to this proxy, or "" on failure. The lookup
 * mechanism is selected at compile time:
 *   LINUX_NETFILTER  - getsockopt(SO_ORIGINAL_DST) on the accepted socket.
 *   IPFW_TRANSPARENT - the socket's local address (getsockname()).
 *   PF_TRANSPARENT   - a DIOCNATLOOK state lookup on /dev/pf.
 *   IPF_TRANSPARENT  - a SIOCGNATL lookup on the IP-Filter NAT device.
 * With none of these configured, a compile-time warning is emitted and ""
 * is returned. (Previously that branch fell off the end of this non-void
 * function, which is undefined behaviour.) */
std::string Socket::getredirectaddress(void)
{
	struct sockaddr_in redirectsockaddr;
	socklen_t redirectsockaddrlen = sizeof(struct sockaddr_in);

#if LINUX_NETFILTER
	if (getsockopt(fd, SOL_IP, SO_ORIGINAL_DST, &redirectsockaddr, &redirectsockaddrlen) < 0)
	{
		syslog(LOG_ERR, "Redirect address, getsockopt() failed");
		return "";
	}

	return sockaddrtostring((struct mysockaddr *) &redirectsockaddr);
#elif IPFW_TRANSPARENT
	/* With ipfw the socket's local address is taken as the original destination. */
	if (getsockname(fd, (struct sockaddr *) &redirectsockaddr, &redirectsockaddrlen) < 0)
	{
		syslog(LOG_ERR, "Redirect address, getsockname() failed");
		return "";
	}

	return sockaddrtostring((struct mysockaddr *) &redirectsockaddr);
#elif PF_TRANSPARENT
	/* The PF lookup is keyed on both ends of the client -> proxy connection. */
	struct sockaddr_in clientsockaddr;
	socklen_t clientsockaddrlen = sizeof(struct sockaddr_in);
	
	if (getpeername(fd, (struct sockaddr *) &clientsockaddr, &clientsockaddrlen) < 0)
	{
		syslog(LOG_ERR, "Redirect address, getpeername() failed");
		return "";
	}
	
	if (getsockname(fd, (struct sockaddr *) &redirectsockaddr, &redirectsockaddrlen) < 0)
	{
		syslog(LOG_ERR, "Redirect address, getsockname() failed");
		return "";
	}

	int pffd = open("/dev/pf", O_RDWR);

	if (pffd < 0)
	{
		syslog(LOG_ERR, "Redirect address, PF (/dev/pf) open failed: %s", strerror(errno));
		syslog(LOG_NOTICE, "Check permissions on /dev/pf. IMSpector needs read/write privileges.");
		return "";
	}

	struct pfioc_natlook nl;
	
	memset(&nl, 0, sizeof(struct pfioc_natlook));
	
	nl.saddr.v4.s_addr = clientsockaddr.sin_addr.s_addr;
	nl.sport = clientsockaddr.sin_port;
	nl.daddr.v4.s_addr = redirectsockaddr.sin_addr.s_addr;
	nl.dport = redirectsockaddr.sin_port;
	nl.af = AF_INET;
	nl.proto = IPPROTO_TCP;
	nl.direction = PF_OUT;

	int pfresult = ioctl(pffd, DIOCNATLOOK, &nl);

	close(pffd);

	if (pfresult < 0)
	{
		syslog(LOG_ERR, "Redirect address, PF lookup failed");
		return "";
	}

	/* The rdaddr/rdport fields of the lookup result are reported as the target. */
	redirectsockaddr.sin_port = nl.rdport;
	redirectsockaddr.sin_addr = nl.rdaddr.v4;

	return sockaddrtostring((struct mysockaddr *) &redirectsockaddr);
#elif IPF_TRANSPARENT
	struct sockaddr_in clientsockaddr;
	socklen_t clientsockaddrlen = sizeof(struct sockaddr_in);
	
	if (getpeername(fd, (struct sockaddr *) &clientsockaddr, &clientsockaddrlen) < 0)
	{
		syslog(LOG_ERR, "Redirect address, getpeername() failed");
		return "";
	}
	
	if (getsockname(fd, (struct sockaddr *) &redirectsockaddr, &redirectsockaddrlen) < 0)
	{
		syslog(LOG_ERR, "Redirect address, getsockname() failed");
		return "";
	}

	struct natlookup natLookup;
	int natfd;
	int x;

#if defined(IPFILTER_VERSION) && (IPFILTER_VERSION >= 4000027)
	struct ipfobj obj;
#else
	static int siocgnatl_cmd = SIOCGNATL & 0xff;
#endif

	/* Start from clean structures so no field is left uninitialised. */
	memset(&natLookup, 0, sizeof(natLookup));

#if defined(IPFILTER_VERSION) && (IPFILTER_VERSION >= 4000027)
	memset(&obj, 0, sizeof(obj));
	obj.ipfo_rev = IPFILTER_VERSION;
	obj.ipfo_size = sizeof(natLookup);
	obj.ipfo_ptr = &natLookup;
	obj.ipfo_type = IPFOBJ_NATLOOKUP;
	obj.ipfo_offset = 0;
#endif

	natLookup.nl_inip = clientsockaddr.sin_addr;
	natLookup.nl_inport = clientsockaddr.sin_port;
	natLookup.nl_outip = redirectsockaddr.sin_addr;
	natLookup.nl_outport = redirectsockaddr.sin_port;
	natLookup.nl_flags = IPN_TCP;

#ifdef IPNAT_NAME
	natfd = open(IPNAT_NAME, O_RDONLY, 0);
#else
	natfd = open(IPL_NAT, O_RDONLY, 0);
#endif

	if (natfd < 0)
	{
#ifdef IPNAT_NAME
		syslog(LOG_ERR, "Redirect address, IP-Filter (%s) open failed: %s", IPNAT_NAME, strerror(errno));
		syslog(LOG_NOTICE, "Check permissions on %s. IMSpector needs read privileges.", IPNAT_NAME);
#else
		syslog(LOG_ERR, "Redirect address, IP-Filter (%s) open failed: %s", IPL_NAT, strerror(errno));
		syslog(LOG_NOTICE, "Check permissions on %s. IMSpector needs read privileges.", IPL_NAT);
#endif
		return "";
	}

#if defined(IPFILTER_VERSION) && (IPFILTER_VERSION >= 4000027)
	x = ioctl(natfd, SIOCGNATL, &obj);
#else
	/* IP-Filter changed the type for SIOCGNATL between
	 * 3.3 and 3.4.  It also changed the cmd value for
	 * SIOCGNATL, so at least we can detect it.  We could
	 * put something in configure and use ifdefs here, but
	 * this seems simpler. */
	if (63 == siocgnatl_cmd)
	{
		struct natlookup *nlp = &natLookup;
		x = ioctl(natfd, SIOCGNATL, &nlp);
	}
	else
		x = ioctl(natfd, SIOCGNATL, &natLookup);
#endif

	close(natfd);

	if (x < 0)
	{
		syslog(LOG_ERR, "Redirect address, IP-Filter lookup failed");
		return "";
	}

	redirectsockaddr.sin_port = natLookup.nl_realport;
	redirectsockaddr.sin_addr = natLookup.nl_realip;

	return sockaddrtostring((struct mysockaddr *) &redirectsockaddr);
#else
#warning "Don't know how to lookup the redirect address; redirect not available"
	return "";
#endif
}

/* Opens a new socket and connects it to remoteaddress ("host:port" for
 * AF_INET, a path otherwise). When interface is non-empty and the platform
 * provides SO_BINDTODEVICE, the socket is first bound to that device.
 *
 * Returns true on success. On failure the problem is logged and false is
 * returned. A descriptor that was already created is closed and reset to -1,
 * matching listensocket(); previously a failed attempt left it open. */
bool Socket::connectsocket(std::string remoteaddress, std::string interface)
{
	if ((fd = socket(domain, type, 0)) < 0)
	{
		syslog(LOG_ERR, "Connect socket, socket() failed");
		return false;
	}

	struct mysockaddr remotename = stringtosockaddr(remoteaddress);

#if defined(SOL_SOCKET) && defined(SO_BINDTODEVICE)
	if (!interface.empty())
	{
		/* The option length includes the device name's terminating NUL. */
		socklen_t interface_len = interface.length() + 1;
		if (setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, interface.c_str(), interface_len) != 0)
		{
			syslog(LOG_ERR, "Connect socket, setsockopt() failed");
			close(fd);
			fd = -1;
			return false;
		}
	}
#endif

	if (connect(fd, (struct sockaddr *) &remotename, SOCK_SIZE(domain)) < 0)
	{
		syslog(LOG_ERR, "Connect socket, connect() failed to %s", remoteaddress.c_str());
		close(fd);
		fd = -1;
		return false;
	}

	return true;
}

#ifdef HAVE_SSL
/* Creates an SSL object from ctx for this socket and switches on
 * SSL_MODE_AUTO_RETRY. Returns false, after logging the OpenSSL error, if
 * SSL_new() fails; true otherwise. */
bool Socket::enablessl(SSL_CTX *ctx)
{
	if ((ssl = SSL_new(ctx)) == NULL)
	{
		syslog(LOG_ERR, "SSL new error: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}

	SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY);
	return true;
}

bool Socket::sslaccept(void)
{
	if (ssl)
	{
		SSL_set_fd(ssl, fd);
	
		if (SSL_accept(ssl) < 0)
		{
			syslog(LOG_DEBUG, "SSL accept warning: %s", ERR_error_string(ERR_get_error(), NULL));
			return false;
		}
	}
	
	return true;
}

bool Socket::sslconnect(void)
{
	if (ssl)
	{
		SSL_set_fd(ssl, fd);
	
		if (SSL_connect(ssl) < 0)
		{
			syslog(LOG_DEBUG, "SSL connect warning: %s", ERR_error_string(ERR_get_error(), NULL));
			return false;
		}
	}

	peercert = SSL_get_peer_certificate(ssl);

	if (!peercert)
	{
		syslog(LOG_ERR, "SSL get peer certificate error: %s", ERR_error_string(ERR_get_error(), NULL));
		return "";
	}
	
	return true;
}

std::string Socket::getpeercommonname(void)
{
	X509_NAME *subject = X509_get_subject_name(peercert);

	if (!subject)
	{
		syslog(LOG_ERR, "X509 get subject name error: %s", ERR_error_string(ERR_get_error(), NULL));
		return "";
	}
	
	X509_NAME_ENTRY *entry = X509_NAME_get_entry(subject,
		X509_NAME_get_index_by_NID(subject, NID_commonName, -1));
		
	if (!entry)
	{
		syslog(LOG_ERR, "X509 NAME get entry error: %s", ERR_error_string(ERR_get_error(), NULL));
		return "";
	}
	
	char *commonname = (char *)ASN1_STRING_data(X509_NAME_ENTRY_get_data(entry));
	
	return (std::string)commonname;
}

/* Returns the certificate verification result that SSL_get_verify_result()
 * reports for this socket's SSL session, narrowed to int. ssl is used
 * unchecked, so this presumably must only be called once SSL is enabled. */
int Socket::getvalidatecertresult(void)
{
	long verifyresult = SSL_get_verify_result(ssl);
	return static_cast<int>(verifyresult);
}

#endif

/* Sends up to length bytes from buffer, through SSL_write() when SSL has been
 * enabled on this socket and with a plain send() otherwise. Returns whatever
 * that call returns: the number of bytes written, or a value below 1 on
 * failure. */
int Socket::senddata(const char *buffer, int length)
{
#ifdef HAVE_SSL
	if (ssl)
		return SSL_write(ssl, buffer, length);
#endif
	return send(fd, buffer, length, 0);
}

/* Sends all length bytes of buffer, calling senddata() repeatedly until
 * everything has been written. Returns false as soon as one call reports
 * failure or zero bytes written, true once the whole buffer has gone out. */
bool Socket::sendalldata(const char *buffer, int length)
{
	for (int offset = 0; offset < length; )
	{
		int chunk = senddata(buffer + offset, length - offset);
		if (chunk <= 0)
			return false;
		offset += chunk;
	}

	return true;
}

/* Sends length bytes of string in full using sendalldata(). Returns length
 * when everything was sent, -1 otherwise. */
int Socket::sendline(const char *string, int length)
{
	if (!sendalldata(string, length))
		return -1;

	return length;
}

/* Reads up to length bytes into buffer, through SSL_read() when SSL has been
 * enabled on this socket and with a plain recv() otherwise. Returns whatever
 * that call returns: the number of bytes read, or a value below 1 on error or
 * when the peer has closed the connection. */
int Socket::recvdata(char *buffer, int length)
{
#ifdef HAVE_SSL
	if (ssl)
		return SSL_read(ssl, buffer, length);
#endif
	return recv(fd, buffer, length, 0);
}

/* Fills buffer with exactly length bytes, blocking in recvdata() until they
 * have all arrived. Returns false as soon as a read fails or the peer closes
 * the connection, true once the buffer is full. */
bool Socket::recvalldata(char *buffer, int length)
{
	int have = 0;

	while (have < length)
	{
		int got = recvdata(buffer + have, length - have);
		if (got < 1)
			return false;
		have += got;
	}

	return true;
}

/* Reads one byte at a time into string until a newline has been stored or
 * length bytes have been read. Returns the number of bytes stored (the
 * newline included), or -1 if a read fails or the peer closes the
 * connection. The buffer is never NUL-terminated. Byte-at-a-time reads are
 * slow; a buffered reader would serve this better. */
int Socket::recvline(char *string, int length)
{
	int count = 0;

	while (count < length)
	{
		if (recvdata(string + count, 1) < 1)
			return -1;

		if (string[count++] == '\n')
			return count;
	}

	/* The line did not fit: hand back what was read so far. */
	return count;
}

int Socket::getfd(void)
{
	return fd;
}

/* Releases everything the socket holds: with SSL support the SSL object and
 * the peer certificate, then the descriptor. Each handle is reset after it is
 * released, so calling this repeatedly is harmless. */
void Socket::closesocket(void)
{
#ifdef HAVE_SSL
	if (ssl != NULL)
	{
		SSL_free(ssl);
		ssl = NULL;
	}

	if (peercert != NULL)
	{
		X509_free(peercert);
		peercert = NULL;
	}
#endif

	if (fd == -1)
		return;

	close(fd);
	fd = -1;
}

/* Private functions here. */

void Socket::setfd(int fdin)
{
	fd = fdin;
}

/* Converts a textual address into a socket address for this socket's domain,
 * returned inside a zero-filled struct mysockaddr.
 *  - AF_INET: "host[:port]". host is tried as a dotted quad first and then
 *    resolved with gethostbyname(); if neither works the address stays
 *    INADDR_NONE. A missing port gives port 0.
 *  - Any other domain: the whole string is the AF_UNIX socket path.
 *
 * Fixes over the previous version:
 *  - The host part is no longer strncpy()'d into a fixed STRING_SIZE buffer,
 *    which was left unterminated for long inputs and let strchr()/inet_addr()
 *    read past its end.
 *  - The path copy is bounded by sizeof(sun_path) - 1, so sun_path always
 *    keeps a terminating NUL and no platform-specific UNIX_PATH_MAX is
 *    needed. */
struct mysockaddr Socket::stringtosockaddr(std::string address)
{
	struct mysockaddr myname;
	
	memset(&myname, 0, sizeof(struct mysockaddr));

	if (domain == AF_INET)
	{
		struct sockaddr_in myname_in;
		memset(&myname_in, 0, sizeof(struct sockaddr_in));

		std::string host = address;
		uint16_t port = 0;

		std::string::size_type colon = address.find(':');
		if (colon != std::string::npos)
		{
			port = (uint16_t) atol(address.c_str() + colon + 1);
			host = address.substr(0, colon);
		}

		in_addr_t ip = inet_addr(host.c_str());

		/* Not a dotted quad: resolve the name. */
		if (ip == INADDR_NONE)
		{
			struct hostent *hostent = gethostbyname(host.c_str());
			if (hostent) memcpy((char *) &ip, hostent->h_addr, sizeof(in_addr_t));
		}
		
		myname_in.sin_family = domain;
		myname_in.sin_port = htons(port);
		myname_in.sin_addr.s_addr = ip;
		
		memcpy(&myname, &myname_in, sizeof(sockaddr_in));
	}
	else
	{
		struct sockaddr_un myname_un;
		memset(&myname_un, 0, sizeof(struct sockaddr_un));

		myname_un.sun_family = domain;
		/* Leave room for the NUL supplied by the memset above. */
		strncpy(myname_un.sun_path, address.c_str(), sizeof(myname_un.sun_path) - 1);
		
		memcpy(&myname, &myname_un, sizeof(sockaddr_un));
	}

	return myname;
}

/* Formats a socket address as text: "a.b.c.d:port" for AF_INET, otherwise
 * the sun_path of an AF_UNIX address. For non-AF_INET families pmyname must
 * point at storage of at least sizeof(struct sockaddr_un) bytes, because
 * that many bytes are copied.
 *
 * The path is now read with an explicit bound (memchr over sun_path). A
 * sun_path that fills the whole array without a terminating NUL therefore
 * no longer causes an over-read when converted to std::string. */
std::string Socket::sockaddrtostring(struct mysockaddr *pmyname)
{
	if (pmyname->sa_family == AF_INET)
	{
		struct sockaddr_in myname_in;
		memcpy(&myname_in, pmyname, sizeof(struct sockaddr_in));

		/* inet_ntoa() uses a static buffer; its result is consumed right here. */
		return stringprintf("%s:%d", inet_ntoa(myname_in.sin_addr), ntohs(myname_in.sin_port));
	}

	struct sockaddr_un myname_un;
	memcpy(&myname_un, pmyname, sizeof(struct sockaddr_un));

	const char *path = myname_un.sun_path;
	const char *nul = (const char *) memchr(path, '\0', sizeof(myname_un.sun_path));
	size_t pathlen = nul ? (size_t) (nul - path) : sizeof(myname_un.sun_path);

	return std::string(path, pathlen);
}
