/* RCS/CVS "$Id$" identification string for this source file.  It is a
   non-static object, presumably so the text stays in the compiled object
   where tools such as ident(1) can find it -- confirm before changing. */
char *loadlib_c = "$Id: loadlib.c,v 1.2 2001/01/11 17:59:38 ana Exp $";

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <lua.h>

#include <lauxlib.h>
#include <lualib.h>

#include "loadlib.h"

/*
  Define one of these macros to select the dynamic-linking interface: WIN32,
  SHL, RLD or DLFCN. If none is defined, dynamic loading is unavailable and
  loadlib always fails.

  systems that support some kind of dynamic linking:
  WIN32: MS Windows 95/98/NT
  SHL: HP-UX
  RLD: NeXT
  DLFCN: Linux, SunOS, IRIX, UNIX_SV, OSF1, SCO_SV, BSD/OS
  DLFCN (simulation): AIX
*/

/* printf template used by loadlib() to build a file name from a directory
   prefix and a bare library name: "<dir>lib<name>.so" on non-Windows
   systems and "<dir><name>.dll" on Windows.  No separator is inserted, so
   a non-empty directory prefix must already end in one. */
#ifndef WIN32
#define MAP_FORMAT "%slib%s.so"
#else
#define MAP_FORMAT "%s%s.dll"
#endif


#if defined(WIN32)

/* Win32 backend: LoadLibrary / GetProcAddress / FreeLibrary. */
#include <windows.h>
typedef HINSTANCE libtype;
/* GetProcAddress returns a generic FARPROC, whose declared calling
   convention is WINAPI.  The exported functions are ordinary lua_CFunctions,
   so cast to that exact type (as the DLFCN backend does) to make
   callfromlib() call them with the matching C calling convention. */
typedef lua_CFunction functype;
#define loadfunc(lib,name) ((functype)GetProcAddress( (lib), (name) ))
#define unloadlibrary(lib) FreeLibrary( lib )
/* Error text comes from GetLastError() via dll_error(), with a fixed
   fallback message when the system has none. */
#define liberror()         dll_error("Could not load library.")
#define funcerror()        dll_error("Could not load function.")
 
/*
 * Load the DLL named by `path` and return its module handle, or NULL on
 * failure (GetLastError() then describes the cause).
 *
 * `path` is a narrow char string, so the ANSI entry point LoadLibraryA is
 * called explicitly; plain LoadLibrary would become LoadLibraryW in UNICODE
 * builds and misinterpret the bytes.
 *
 * Retries happen only while the failure is ERROR_SHARING_VIOLATION (the
 * file is in use by another process): up to 10 more attempts, 2 ms apart.
 * Any other error is returned immediately.
 */
static libtype loadlibrary (const char* path)
{
  int retries = 10;
  libtype libhandle = LoadLibraryA(path);
  /* GetLastError() is read right after the most recent LoadLibraryA call */
  while (!libhandle &&
         retries > 0 &&
         GetLastError() == ERROR_SHARING_VIOLATION) {
    retries--;
    Sleep( 2 );
    libhandle = LoadLibraryA( path );
  }
  return libhandle;
}
 
#define BUFFER_SIZE 100
 
static char* dll_error (char* altmsg)
{
  static char buffer[BUFFER_SIZE+1];
  if ( FormatMessage(FORMAT_MESSAGE_IGNORE_INSERTS |
                     FORMAT_MESSAGE_FROM_SYSTEM,
                     0, /* source */
                     GetLastError(),
                     0, /* langid */
                     buffer,
                     BUFFER_SIZE,
                     0 /* arguments */) ) {
    return buffer;
  }
  else {
    return altmsg;
  }
}
 
#elif defined(DLFCN)

/* dlfcn backend: dlopen / dlsym / dlclose / dlerror. */
#include <dlfcn.h>
/* If <dlfcn.h> does not define RTLD_GLOBAL, use 0 so the flag is simply
   omitted from the dlopen() mode below. */
#ifndef RTLD_GLOBAL
#define RTLD_GLOBAL 0
#endif
typedef void* libtype;
typedef lua_CFunction functype;
/* Libraries are opened with RTLD_LAZY | RTLD_GLOBAL. */
#define loadlibrary(path)  dlopen( path, RTLD_LAZY | RTLD_GLOBAL )
/* dlsym yields a void*; it is cast to the Lua C function type. */
#define loadfunc(lib,name) (functype)dlsym( lib, name )
#define unloadlibrary(lib) dlclose( lib )
/* Both error macros report the text from dlerror(). */
#define liberror()         dlerror()
#define funcerror()        dlerror()

#elif defined(SHL)

/* HP-UX backend: shl_load / shl_findsym / shl_unload. */
#include <dl.h>
typedef shl_t libtype;
typedef lua_CFunction functype;
/* Loaded with the flags BIND_DEFERRED | BIND_NOSTART and address 0L. */
#define loadlibrary(path)  shl_load( path, BIND_DEFERRED | BIND_NOSTART, 0L )
#define unloadlibrary(lib) shl_unload( lib )
/* This backend reports fixed error texts rather than a system message. */
#define liberror()         "Could not load library."
#define funcerror()        "Could not load function."

/*
 * Look up the procedure `name` in shared library `lib` via shl_findsym.
 * Returns the procedure's address, or 0 when shl_findsym reports -1.
 * `lib` is passed by address because that is what shl_findsym takes; only
 * this local copy is handed over.
 */
static functype loadfunc (libtype lib, char *name)
{
  functype found = 0;
  int status = shl_findsym( &lib, name, TYPE_PROCEDURE, &found );
  return (status == -1) ? 0 : found;
}

#elif defined(RLD)

/* NeXT backend: rld (run-time link editor). */
#include <rld.h>
typedef long libtype;
typedef lua_CFunction functype;
/* The long result of rldload() is used as the library handle.
   NOTE(review): rldload() appears to expect a NULL-terminated array of
   object file names as its third argument rather than a single path --
   confirm against <rld.h>. */
#define loadlibrary(path)  rldload( 0, 0, path, 0  )
/* No unload support: the argument is discarded and never evaluated. */
#define unloadlibrary(lib) ;
#define liberror()         "Could not load library."
#define funcerror()        "Could not load function."

/*
 * Look up `name` with rld_lookup, searching for the symbol with a leading
 * underscore prepended ("foo" -> "_foo"), presumably because C symbols
 * carry that prefix on this platform.  `lib` is not used: rld_lookup is
 * called with a null stream, as before.
 *
 * Returns the function's address, or 0 if the lookup fails or the
 * temporary symbol name cannot be allocated.
 */
static functype loadfunc (libtype lib, const char *name)
{
  functype fn = 0;
  size_t len = strlen(name);
  char *sym = malloc(len + 2);   /* '_' + name + '\0' */
  (void)lib;
  if (!sym)
    return 0;
  sym[0] = '_';
  memcpy(sym + 1, name, len + 1);  /* copies the terminator too */
  if (!rld_lookup( 0, sym, &fn ))
    fn = 0;
  free(sym);
  return fn;
}

#else

/* No dynamic-linking backend selected: loadlibrary() always yields 0, so
   loadlib returns nil plus "Dynamic libraries not supported.", and
   loadfunc() always yields 0. */
typedef void* libtype;
typedef lua_CFunction functype;
#define loadlibrary(path)  (0)
#define loadfunc(lib,name) (0)
/* The argument is discarded and never evaluated. */
#define unloadlibrary(lib) ;
#define liberror()         "Dynamic libraries not supported."
#define funcerror()        ""

#endif

/* Stack indices of the two tag upvalues that loadlib_open attaches to every
   exported function: the live-library tag is pushed first (so it sits at
   -2) and the unloaded-library tag last (at -1). */
enum {
  LIBTAG      = -2,
  UNLOADEDTAG = -1
};

/* Return the number stored at stack index `i` truncated to an int; used to
   read the tag upvalues addressed by LIBTAG and UNLOADEDTAG. */
static int gettag (lua_State *L, int i)
{
  int tag;
  tag = (int)lua_tonumber(L, i);
  return tag;
}


/*
 * Ensure argument `nparam` is a userdata carrying the live-library tag
 * (read from the LIBTAG upvalue) and return the handle it wraps.  Raises a
 * Lua argument error otherwise, which also rejects handles already retagged
 * by unloadlib.
 */
static libtype check_libhandle (lua_State *L, int nparam)
{
  int wanted_tag;

  luaL_arg_check(L, lua_isuserdata(L, nparam), nparam, "userdata expected");
  wanted_tag = gettag(L, LIBTAG);
  luaL_arg_check(L, lua_tag(L, nparam) == wanted_tag,
                 nparam, "not a valid library handle" );
  return (libtype)lua_touserdata(L, nparam);
}

/*
 * loadlib(libname [, dir]) -> handle | nil, errmsg
 *
 * Loads a dynamic library and returns it as a userdata tagged with the
 * live-library tag.  If `libname` contains any of the characters
 * . : / \  it is used verbatim as the file name.  Otherwise the file name is
 * built from MAP_FORMAT with the optional prefix `dir` (default "").
 * On failure returns nil and the backend's error message.
 */
static int loadlib (lua_State *L)
{
  const char *libname;
  const char *path;      /* file name actually handed to loadlibrary() */
  char *buffer = NULL;   /* heap storage for a generated file name, if any */
  const char *errmsg;
  libtype lib;

  int tag = gettag( L, LIBTAG );
  /* drop the two tag upvalues so an omitted arg 2 is not read as one */
  lua_pop(L,2);

  libname = luaL_check_string(L, 1);
  if (strpbrk(libname, ".:/\\"))
    path = libname;
  else {
    const char *dir = luaL_opt_string(L, 2, "");
    /* strlen(MAP_FORMAT) also counts the "%s" specifiers, so this slightly
       over-estimates the room needed for the literal text around dir/name */
    buffer = malloc(strlen(dir) + strlen(libname) + strlen(MAP_FORMAT) + 1);
    if (!buffer) lua_error(L, "not enough memory.");
    sprintf(buffer, MAP_FORMAT, dir, libname);
    path = buffer;
  }

  lib = loadlibrary(path);
  if ( !lib ) {
    /* fetch the message before free(): on WIN32 it is derived from
       GetLastError(), which must be read right after the failing call */
    errmsg = liberror();
    free(buffer);
    lua_pushnil(L);
    lua_pushstring(L, errmsg);
    return 2;
  }
  free(buffer);
  lua_pushusertag(L, lib, tag);
  return 1;
}

/*
 * callfromlib(handle, funcname)
 *
 * Looks up `funcname` in the library `handle` (from loadlib) and calls it
 * once as a lua_CFunction on this call's stack.  The stack is passed as is:
 * the arguments followed by the two tag upvalues, which are not popped
 * here.  Whatever the called function returns is discarded, and
 * callfromlib itself returns no values.  Raises a Lua error if the symbol
 * cannot be found.
 */
static int callfromlib (lua_State *L)
{
  libtype lh = check_libhandle(L, 1);
  const char *funcname = luaL_check_string(L, 2);
  functype fn = loadfunc(lh, funcname);
  if (!fn) {
    lua_error(L, funcerror());
    return 0;  /* not reached: lua_error does not return */
  }
  fn(L);
  return 0;
}

/*
 * unloadlib(handle)
 *
 * Validates `handle` as a live library handle, unloads the library and
 * retags the handle with the unloaded tag.  Later callfromlib/unloadlib
 * calls on it then fail check_libhandle.  Returns no values.
 */
static int unloadlib (lua_State *L)
{
  int unloadedtag = gettag(L,UNLOADEDTAG);
  /* Validate in its own statement: on some backends unloadlibrary() expands
     to an empty statement that discards its argument, which would otherwise
     skip the check and let any value be retagged. */
  libtype lib = check_libhandle(L, 1);
  unloadlibrary(lib);
  lua_pushvalue(L, 1);
  lua_settag(L,unloadedtag);
  return 0;
}

/*
 * Register loadlib, unloadlib and callfromlib as globals and set the global
 * LOADLIB_VERSION.  Two new tags are created: one for live library handles
 * and one for unloaded ones.  They are attached to every function as two
 * upvalues, in that order, so the functions find them at stack indices
 * LIBTAG (-2) and UNLOADEDTAG (-1).
 */
void loadlib_open( lua_State *L )
{
  static const struct luaL_reg funcs[] = {
    {"loadlib",     loadlib},
    {"unloadlib",   unloadlib},
    {"callfromlib", callfromlib}
  };
  int libtag, unloadedtag;
  size_t i;  /* size_t to match the sizeof-based bound (no sign-compare) */
  libtag = lua_newtag(L);
  unloadedtag = lua_newtag(L);
  for (i=0; i<sizeof(funcs)/sizeof(funcs[0]); i++)
  {
    lua_pushnumber(L, libtag);
    lua_pushnumber(L, unloadedtag);
    lua_pushcclosure(L, funcs[i].func, 2);
    lua_setglobal(L, funcs[i].name);
  }
  lua_pushstring(L, LOADLIB_VERSION);
  lua_setglobal(L, "LOADLIB_VERSION");
}
