/* 
   Unix SMB/CIFS implementation.

   multiple interface handling

   Copyright (C) Andrew Tridgell 1992-2005
   
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 3 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "libmapi/libmapi.h"
#include "libmapi/libmapi_private.h"
#include <param.h>

#define ALLONES  ((uint32_t)0xFFFFFFFF)
/*
  address construction based on a patch from fred@datalync.com

  MKBCADDR(ip, nm)  : broadcast address of the network containing ip
                      (network bits of ip, host bits all set)
  MKNETADDR(ip, nm) : network address (host bits cleared)

  Both operate on 32-bit values; byte order does not matter because
  only bitwise operators are used. Every argument is parenthesised so
  that expression arguments (e.g. "a | b") keep their meaning.
*/
#define MKBCADDR(ip_, nm_) (((ip_) & (nm_)) | ((nm_) ^ ALLONES))
#define MKNETADDR(ip_, nm_) ((ip_) & (nm_))

/* Helpers used below; they are not defined in this file. */
bool is_zero_ip_v4(struct in_addr ip);
bool same_net_v4(struct in_addr ip1, struct in_addr ip2, struct in_addr mask);
uint32_t interpret_addr(const char *str);
struct in_addr interpret_addr2(const char *str);

/****************************************************************************
Look up an interface in the list.

With CheckMask true, the first interface whose network (per same_net_v4()
with that interface's netmask) contains ip is returned; otherwise the
first interface whose address equals ip exactly. If is_zero_ip_v4()
reports ip as zero, the list head is returned unconditionally. Returns
NULL when nothing matches.
****************************************************************************/
static struct interface *iface_find(struct interface *interfaces, 
				    struct in_addr ip, bool CheckMask)
{
	struct interface *cur;

	if (is_zero_ip_v4(ip)) {
		return interfaces;
	}

	for (cur = interfaces; cur != NULL; cur = cur->next) {
		bool hit;

		if (CheckMask) {
			hit = same_net_v4(cur->ip, ip, cur->nmask);
		} else {
			hit = (cur->ip.s_addr == ip.s_addr);
		}

		if (hit) {
			return cur;
		}
	}

	return NULL;
}


/****************************************************************************
add an interface to the linked list of interfaces

The address and netmask are stored together with their dotted-quad
string forms. An address that iface_find() (exact comparison, no mask
check) already reports as present is not added again. If any allocation
fails, nothing is added to the list.
****************************************************************************/
static void add_interface(TALLOC_CTX *mem_ctx, struct in_addr ip, struct in_addr nmask, struct interface **interfaces)
{
	struct interface *iface;
	struct in_addr bcast;

	if (iface_find(*interfaces, ip, false)) {
		OC_DEBUG(3, "not adding duplicate interface %s", inet_ntoa(ip));
		return;
	}

	/* the first entry is parented on mem_ctx and later ones on the list
	   head, so talloc_free() of the head releases every entry */
	iface = talloc(*interfaces == NULL ? mem_ctx : *interfaces, struct interface);
	if (iface == NULL) {
		return;
	}

	OC_ZERO_STRUCTPN(iface);

	iface->ip = ip;
	iface->nmask = nmask;
	bcast.s_addr = MKBCADDR(iface->ip.s_addr, iface->nmask.s_addr);

	/* keep string versions too, to avoid people tripping over the implied
	   static in inet_ntoa() */
	iface->ip_s = talloc_strdup(iface, inet_ntoa(iface->ip));
	iface->nmask_s = talloc_strdup(iface, inet_ntoa(iface->nmask));
	if (iface->ip_s == NULL || iface->nmask_s == NULL) {
		/* never link in an entry whose strings are NULL: readers return
		   them directly and the debug message below formats them */
		talloc_free(iface);
		return;
	}

	/* an all-ones netmask has no separate broadcast address, so bcast_s
	   keeps the value OC_ZERO_STRUCTPN gave it (presumably NULL) */
	if (nmask.s_addr != ~(in_addr_t)0) {
		iface->bcast_s = talloc_strdup(iface, inet_ntoa(bcast));
		if (iface->bcast_s == NULL) {
			talloc_free(iface);
			return;
		}
	}

	DLIST_ADD_END(*interfaces, iface, struct interface *);

	OC_DEBUG(2, "added interface ip=%s nmask=%s", iface->ip_s, iface->nmask_s);
}

/*
  Protocol levels, numbered in ascending order (names mirror the SMB
  dialects). This file only relies on the ordering relative to
  PROTOCOL_LANMAN2 and on PROTOCOL_NT1.
*/
enum oc_protocol_types {
	PROTOCOL_DEFAULT   = -1,
	PROTOCOL_NONE      = 0,
	PROTOCOL_CORE      = 1,
	PROTOCOL_COREPLUS  = 2,
	PROTOCOL_LANMAN1   = 3,
	PROTOCOL_LANMAN2   = 4,
	PROTOCOL_NT1       = 5,
	PROTOCOL_SMB2_02   = 6,
	PROTOCOL_SMB2_10   = 7,
	PROTOCOL_SMB2_22   = 8,
	PROTOCOL_SMB2_24   = 9,
	PROTOCOL_SMB3_00   = 10,
	PROTOCOL_SMB3_02   = 11,
	PROTOCOL_SMB3_10   = 12,
	PROTOCOL_SMB3_11   = 13
};

/*
  Report whether the remaining pattern p can match an empty string:
  returns 0 when every character is one of '*', '<', '"' or '>'
  (including when p is empty), -1 otherwise.
*/
static int openchange_null_match(const char *p)
{
	const char *cur = p;

	while (*cur != '\0') {
		switch (*cur) {
		case '*':
		case '<':
		case '"':
		case '>':
			cur++;
			break;
		default:
			return -1;
		}
	}
	return 0;
}

/*
  the max_n structure is purely for efficiency, it doesn't contribute
  to the matching algorithm except by ensuring that the algorithm does
  not grow exponentially

  openchange_ms_fnmatch_protocol() allocates one zeroed entry per '*'
  or '<' in the pattern; openchange_ms_fnmatch_core() hands max_n+1 to
  the recursion that follows a wildcard, so each wildcard uses its own
  entry.
*/
struct max_n {
	/* lowest string position from which this wildcard has already been
	   expanded without finding a match; a later visit at or after it
	   only succeeds if the rest of the pattern can match empty */
	const char *predot;
	/* set by '<' when its expansion reached the last '.' without a
	   match; a later visit at or after it (but not past the last '.')
	   fails immediately */
	const char *postdot;
};


/*
  Recursive core of the MS-style wildcard matcher.

  p and n are the pattern and string being matched. The max_n array is
  an optimisation only. The ldot pointer is NULL if the string does
  not contain a '.', otherwise it points at the last dot in 'n'.

  Pattern characters:
    '*'  matches zero or more characters
    '<'  like '*', but cannot extend beyond the last '.' of the string
         (it may consume up to and including that dot)
    '?'  matches exactly one character (fails at end of string)
    '>'  matches one character; at a '.' it matches nothing, and at the
         end of the string it succeeds only if the rest of the pattern
         can match an empty string
    '"'  matches a '.', or the end of the string when the rest of the
         pattern can match an empty string
    other characters are compared exactly or via codepoint_cmpi()
    (presumably case-insensitive)

  next_codepoint() presumably decodes one (possibly multi-byte)
  character and stores its byte length in its second argument.

  Returns 0 on a match, -1 otherwise.
*/
static int openchange_ms_fnmatch_core(const char *p, const char *n, 
				      struct max_n *max_n, const char *ldot)
{
	codepoint_t c, c2;
	int i;
	size_t size, size_n;

	while ((c = next_codepoint(p, &size))) {
		p += size;

		switch (c) {
		case '*':
			/* a '*' matches zero or more characters of any type */
			/* every start position at or after predot was already
			   tried for this wildcard without success */
			if (max_n->predot && max_n->predot <= n) {
				return openchange_null_match(p);
			}
			/* try the rest of the pattern at each character
			   boundary of the string */
			for (i=0; n[i]; i += size_n) {
				next_codepoint(n+i, &size_n);
				if (openchange_ms_fnmatch_core(p, n+i, max_n+1, ldot) == 0) {
					return 0;
				}
			}
			/* remember the failure so later visits can short-circuit */
			if (!max_n->predot || max_n->predot > n) max_n->predot = n;
			return openchange_null_match(p);

		case '<':
			/* a '<' matches zero or more characters of
			   any type, but stops matching at the last
			   '.' in the string. */
			if (max_n->predot && max_n->predot <= n) {
				return openchange_null_match(p);
			}
			if (max_n->postdot && max_n->postdot <= n && n <= ldot) {
				return -1;
			}
			for (i=0; n[i]; i += size_n) {
				next_codepoint(n+i, &size_n);
				if (openchange_ms_fnmatch_core(p, n+i, max_n+1, ldot) == 0) return 0;
				if (n+i == ldot) {
					/* at the last dot: also try just after it,
					   then give up - '<' may not go further */
					if (openchange_ms_fnmatch_core(p, n+i+size_n, max_n+1, ldot) == 0) return 0;
					if (!max_n->postdot || max_n->postdot > n) max_n->postdot = n;
					return -1;
				}
			}
			if (!max_n->predot || max_n->predot > n) max_n->predot = n;
			return openchange_null_match(p);

		case '?':
			/* a '?' matches any single character */
			if (! *n) {
				return -1;
			}
			next_codepoint(n, &size_n);
			n += size_n;
			break;

		case '>':
			/* a '>' matches any single character, but
			   treats '.' specially: at a '.' it matches
			   nothing (n is not advanced) */
			if (n[0] == '.') {
				if (! n[1] && openchange_null_match(p) == 0) {
					return 0;
				}
				break;
			}
			if (! *n) return openchange_null_match(p);
			next_codepoint(n, &size_n);
			n += size_n;
			break;

		case '"':
			/* a bit like a soft '.' */
			if (*n == 0 && openchange_null_match(p) == 0) {
				return 0;
			}
			if (*n != '.') return -1;
			next_codepoint(n, &size_n);
			n += size_n;
			break;

		default:
			/* literal character (c2 is 0 at end of string,
			   which never equals a non-NUL pattern char) */
			c2 = next_codepoint(n, &size_n);
			if (c != c2 && codepoint_cmpi(c, c2) != 0) {
				return -1;
			}
			n += size_n;
			break;
		}
	}
	
	/* pattern exhausted: match only if the string is too */
	if (! *n) {
		return 0;
	}
	
	return -1;
}

/*
  Match string against an MS-style wildcard pattern using the rules of
  the given negotiated protocol level.

  Returns 0 on a match. When the pattern has no wildcard characters the
  result of strcasecmp_m() is returned as-is; otherwise -1 signals no
  match or an allocation failure.
*/
static int openchange_ms_fnmatch_protocol(const char *pattern, const char *string, int protocol)
{
	struct max_n *limits;
	char *rewritten;
	int wildcards = 0;
	int j;
	int rc;

	/* ".." is compared as though it were "." */
	if (strcmp(string, "..") == 0) {
		string = ".";
	}

	/* no wildcard characters at all: a plain case-insensitive
	   comparison is used. This is not just an optimisation - it is
	   essential for LANMAN1 correctness. */
	if (strpbrk(pattern, "<>*?\"") == NULL) {
		return strcasecmp_m(pattern, string);
	}

	if (protocol > PROTOCOL_LANMAN2) {
		/* one max_n slot per '*' or '<' in the pattern */
		for (j = 0; pattern[j] != '\0'; j++) {
			if (pattern[j] == '*' || pattern[j] == '<') {
				wildcards++;
			}
		}

		limits = talloc_zero_array(NULL, struct max_n, wildcards);
		if (limits == NULL) {
			return -1;
		}

		rc = openchange_ms_fnmatch_core(pattern, string, limits,
						strrchr(string, '.'));
		talloc_free(limits);
		return rc;
	}

	/*
	  Older negotiated protocols: translate the pattern into a
	  "new style" pattern that exactly matches w2k behaviour, then
	  match it with the NT1 rules.
	*/
	rewritten = talloc_strdup(NULL, pattern);
	if (rewritten == NULL) {
		return -1;
	}

	for (j = 0; rewritten[j] != '\0'; j++) {
		/* the following character has not been rewritten yet */
		char next = rewritten[j + 1];

		switch (rewritten[j]) {
		case '?':
			rewritten[j] = '>';
			break;
		case '.':
			if (next == '?' || next == '*' || next == '\0') {
				rewritten[j] = '"';
			}
			break;
		case '*':
			if (next == '.') {
				rewritten[j] = '<';
			}
			break;
		default:
			break;
		}
	}

	rc = openchange_ms_fnmatch_protocol(rewritten, string, PROTOCOL_NT1);
	talloc_free(rewritten);
	return rc;
}


/**
  A generic fnmatch function, used for non-CIFS pattern matching.

  Always applies the PROTOCOL_NT1 matching rules. Returns 0 when string
  matches pattern and non-zero otherwise.
**/
static int openchange_gen_fnmatch(const char *pattern, const char *string)
{
	const int rules = PROTOCOL_NT1;

	return openchange_ms_fnmatch_protocol(pattern, string, rules);
}


/**
interpret a single element from a interfaces= config line 

This handles the following different forms:

1) wildcard interface name
2) DNS name
3) IP/masklen
4) ip/mask
5) bcast/mask
**/
static void interpret_interface(TALLOC_CTX *mem_ctx, 
				const char *token, 
				struct iface_struct *probed_ifaces, 
				int total_probed,
				struct interface **local_interfaces)
{
	struct in_addr ip, nmask;
	char *p;
	char *address;
	int i, added=0;

	ip.s_addr = 0;
	nmask.s_addr = 0;
	
	/* first check if it is an interface name */
	for (i=0;i<total_probed;i++) {
		if (openchange_gen_fnmatch(token, probed_ifaces[i].name) == 0) {
			add_interface(mem_ctx, probed_ifaces[i].ip,
				      probed_ifaces[i].netmask,
				      local_interfaces);
			added = 1;
		}
	}
	if (added) return;

	/* maybe it is a DNS name */
	p = strchr_m(token,'/');
	if (!p) {
		/* don't try to do dns lookups on wildcard names */
		if (strpbrk(token, "*?") != NULL) {
			return;
		}
		ip.s_addr = interpret_addr2(token).s_addr;
		for (i=0;i<total_probed;i++) {
			if (ip.s_addr == probed_ifaces[i].ip.s_addr) {
				add_interface(mem_ctx, probed_ifaces[i].ip,
					      probed_ifaces[i].netmask,
					      local_interfaces);
				return;
			}
		}
		OC_DEBUG(2, "can't determine netmask for %s", token);
		return;
	}

	address = talloc_strdup(mem_ctx, token);
	p = strchr_m(address,'/');

	/* parse it into an IP address/netmasklength pair */
	*p++ = 0;

	ip.s_addr = interpret_addr2(address).s_addr;

	if (strlen(p) > 2) {
		nmask.s_addr = interpret_addr2(p).s_addr;
	} else {
		nmask.s_addr = htonl(((ALLONES >> atoi(p)) ^ ALLONES));
	}

	/* maybe the first component was a broadcast address */
	if (ip.s_addr == MKBCADDR(ip.s_addr, nmask.s_addr) ||
	    ip.s_addr == MKNETADDR(ip.s_addr, nmask.s_addr)) {
		for (i=0;i<total_probed;i++) {
			if (same_net_v4(ip, probed_ifaces[i].ip, nmask)) {
				add_interface(mem_ctx, probed_ifaces[i].ip, nmask,
					      local_interfaces);
				talloc_free(address);
				return;
			}
		}
		OC_DEBUG(2, "Can't determine ip for broadcast address %s", address);
		talloc_free(address);
		return;
	}

	add_interface(mem_ctx, ip, nmask, local_interfaces);
	talloc_free(address);
}


/**
load the list of network interfaces

*local_interfaces is first reset to NULL. If interfaces is NULL, has no
entries, or its first entry is an empty string, every probed interface
except 127.0.0.1 is added. Each entry of interfaces is then handed to
interpret_interface(). A warning is logged if the list ends up empty.
**/
void openchange_load_interfaces(TALLOC_CTX *mem_ctx, const char **interfaces, struct interface **local_interfaces)
{
	struct iface_struct probed[MAX_INTERFACES];
	struct in_addr loopback;
	const char **entry;
	int nprobed;
	int idx;

	*local_interfaces = NULL;

	loopback = interpret_addr2("127.0.0.1");

	/* ask the kernel which interfaces exist */
	nprobed = get_interfaces_oc(probed, MAX_INTERFACES);

	/* no usable interfaces= line: fall back to all probed interfaces
	   apart from loopback */
	if (interfaces == NULL || interfaces[0] == NULL || interfaces[0][0] == '\0') {
		if (nprobed <= 0) {
			oc_log(OC_LOG_ERROR, "Could not determine network interfaces, you must use a interfaces config line");
		}
		for (idx = 0; idx < nprobed; idx++) {
			if (probed[idx].ip.s_addr == loopback.s_addr) {
				continue;
			}
			add_interface(mem_ctx, probed[idx].ip,
				      probed[idx].netmask, local_interfaces);
		}
	}

	for (entry = interfaces; entry != NULL && *entry != NULL; entry++) {
		interpret_interface(mem_ctx, *entry, probed, nprobed, local_interfaces);
	}

	if (*local_interfaces == NULL) {
		oc_log(OC_LOG_WARNING, "no network interfaces found");
	}
}

/**
  how many interfaces do we have
  **/
int libmapi_iface_count(struct interface *ifaces)
{
	int ret = 0;
	struct interface *i;

	for (i=ifaces;i;i=i->next)
		ret++;
	return ret;
}

/**
  return IP of the Nth interface

  n is a 0-based index into the list. Returns the dotted-quad string of
  that entry, or NULL when the list is shorter than n+1 entries or n is
  negative.
  **/
const char *libmapi_iface_n_ip(struct interface *ifaces, int n)
{
	struct interface *cur = ifaces;

	while (cur != NULL && n != 0) {
		cur = cur->next;
		n--;
	}

	return (cur != NULL) ? cur->ip_s : NULL;
}

/**
  return bcast of the Nth interface

  n is a 0-based index into the list. Returns NULL when the list is
  shorter than n+1 entries or n is negative. The stored string may also
  be NULL, because add_interface() only fills it in when the netmask is
  not all ones.
  **/
const char *libmapi_iface_n_bcast(struct interface *ifaces, int n)
{
	struct interface *cur = ifaces;

	while (cur != NULL && n != 0) {
		cur = cur->next;
		n--;
	}

	return (cur != NULL) ? cur->bcast_s : NULL;
}

/**
  return netmask of the Nth interface

  n is a 0-based index into the list. Returns the dotted-quad netmask of
  that entry, or NULL when the list is shorter than n+1 entries or n is
  negative.
  **/
const char *libmapi_iface_n_netmask(struct interface *ifaces, int n)
{
	struct interface *cur = ifaces;

	while (cur != NULL && n != 0) {
		cur = cur->next;
		n--;
	}

	return (cur != NULL) ? cur->nmask_s : NULL;
}

/**
  return the local IP address that best matches a destination IP, or
  our first interface if none match
*/
const char *libmapi_iface_best_ip(struct interface *ifaces, const char *dest)
{
	struct interface *iface;
	struct in_addr ip;

	ip.s_addr = interpret_addr(dest);
	iface = iface_find(ifaces, ip, true);
	if (iface) {
		return iface->ip_s;
	}
	return libmapi_iface_n_ip(ifaces, 0);
}

/**
  return true if an IP is one one of our local networks
*/
bool libmapi_iface_is_local(struct interface *ifaces, const char *dest)
{
	struct in_addr ip;

	ip.s_addr = interpret_addr(dest);
	if (iface_find(ifaces, ip, true)) {
		return true;
	}
	return false;
}

/**
  return true if a IP matches a IP/netmask pair
*/
bool libmapi_iface_same_net(const char *ip1, const char *ip2, const char *netmask)
{
	return same_net_v4(interpret_addr2(ip1),
			interpret_addr2(ip2),
			interpret_addr2(netmask));
}
