ThreadLocal

Mr.ZhangJava大约 8 分钟

ThreadLocal

提示

对于无锁编程,我们已经讲解了CAS、原子类、累加器,本节,我们讲解无锁编程中的最后一个知识点:ThreadLocal。我们知道,共享变量是代码存在线程安全的根本原因之一。在某些特殊的业务场景下,我们可以使用ThreadLocal线程局部变量替代共享变量,以实现在不需要加锁的情况下达到线程安全。

一、基本用法

在Java中,我们可以将变量粗略的划分为两类:类的成员变量和函数内局部变量。对于类的成员变量,当多个线程使用同一个对象时,对象中的成员变量就是共享变量,其作用域范围为多个线程均可见。多个线程竞态访问成员变量,就有可能存在线程安全问题。对于函数内局部变量,每个线程在执行函数时,会在自己的栈上创建私有的局部变量,因此,函数内局部变量的作用域范围为单线程内可见。不仅如此,函数内局部变量仅限函数内可见,不同的函数之间不可以共享局部变量。多个函数共享局部变量,需要通过参数传递的方式来实现。我们举例解释一下。

假设我们现在有这样一个需求:在一个标准的Controller-Service-Repository三层结构的后端系统中,我们希望实现一个简单的调用链追踪功能。每个接口请求所对应的所有日志都附带同一个traceId,这样,我们通过traceId便可以轻松得到一个接口请求的所有日志,方便通过日志查找代码问题。以下是调用链追踪功能的一种实现方式。traceId定义为函数内的局部变量,需要通过参数传递给被调用函数使用。

public class UserController {
  private static final Logger logger = 
      LoggerFactory.getLogger(UserController.class);
  private UserService userService = new UserService();

  public long login(String username, String password) {
    // 创建traceId
    String traceId = "[" + System.currentTimeMillis() + "]";
    // 所有的日志都带有traceId
    logger.trace(traceId + " username=" + username);
    //...省略校验逻辑...
    return userService.login(username, password, traceId);//传递traceId
  }
}

上述代码存在的问题很明显,我们需要在每个函数中都定义traceId参数,导致非业务代码和业务代码耦合在一起。为了解决这个问题,JUC便提供了ThreadLocal,其作用域范围介于类的成员变量和函数内局部变量之间,它既是线程私有的,又可以在函数之间共享,这样既可以避免线程安全问题,又能避免变量在函数之间不停传递。

ThreadLocal的提供的函数如下所示。其中,set()函数用来将变量值存储在当前线程中,get()函数用来从当前线程中取出变量值。remove()函数用来从当前线程删除变量。initialValue()函数是一个protected方法,可以在子类中重新实现,用于提供变量的初始值。

public class ThreadLocal<T> {
    protected T initialValue() { return null; }
    public T get();
    public void set(T value);
    public void remove();
}

我们使用ThreadLocal重新实现调用链追踪功能,对应的代码如下所示。我们实现了一个匿名类,继承自ThreadLocal,重写了initialValue()来提供threadLocalTraceId的初始值。如果在调用get()函数之前没有调用set()函数设置threadLocalTraceId的值,ThreadLocal会调用我们提供的initialValue()函数,使用其返回值初始化threadLocalTraceId。对于详细的处理逻辑,我们待会结合源码来讲解。

public class Context {
  private static final ThreadLocal<String> threadLocalTraceId = 
      new ThreadLocal<String>() {
    @Override
    protected String initialValue() {
      return "[" + System.currentTimeMillis() + "]";
    }
  };

  public static void setTraceId(String traceId) {
    threadLocalTraceId.set(traceId);
  }

  public static String getTraceId() {
    return threadLocalTraceId.get();
  }

  public static void remove() {
    threadLocalTraceId.remove();
  }
}

public class UserController {
  private static final Logger logger = 
      LoggerFactory.getLogger(UserController.class);
  private UserService userService = new UserService();

  public long login(String username, String password) {
    // 所有的日志都带有traceId
    logger.trace(Context.getTraceId() + " username=" + username);
    //...省略校验逻辑...
    return userService.login(username, password); //通过Context传递traceId
  }
}

二、实现原理

很多人对ThreadLocal的用法都感到奇怪,往往会有这样的疑问:使用ThreadLocal定义的一个变量怎么存储多个线程的数据?或者说在类中的定义的ThreadLocal变量如何“分身”到多个线程中使用的?有这样疑问的主要原因是对ThreadLocal的底层实现原理没有了解。为了回答以上疑问,接下来,我们就结合源码来了解一下ThreadLocal的底层实现原理。

ThreadLocal的代码结构如下所示。从如下代码中,我们可以发现,ThreadLocal类只定义了读写数据的方法,并没有定义任何成员变量来存储数据。那么,set()函数写入的数据存储在哪里呢?get()函数又是从哪里读取数据的呢?

public class ThreadLocal<T> {
    public ThreadLocal() {}
    protected T initialValue() { return null; }
    public T get() { ... }    
    public void set(T value) { ... }
    public void remove() { ... }
    
    public static class ThreadLocalMap {
      ...
    }
}

public class Thread implements Runnable {
    ...
    ThreadLocal.ThreadLocalMap threadLocals = null;
    ...
}

要回答以上问题,我们需要了解ThreadLocal的数据存储结构,如下图所示。从图中,我们可以发现,实际上,数据是存储在线程对应的Thread对象的threadLocals成员变量中的。threadLocals成员变量的类型为ThreadLocalMap。ThreadLocalMap是ThreadLocal的内部类。ThreadLocalMap类似HashMap,也是用来存储键值对,其中,键为ThreadLocal对象,值为Object对象。

img
img

接下来,我们依次讲解一下ThreadLocal中的set()、get()、remove()三个函数的底层实现原理。

1)set()函数的底层实现原理

set()函数的源码如下所示。对ThreadLocal的存储结构有了了解之后,我们便很容易看懂set()函数的代码逻辑。对于set()函数的代码逻辑,我在代码中添加了详细的注释,这里就不再赘述了。

public void set(T value) {
    Thread t = Thread.currentThread(); //获取当前线程对应的Thread对象
    ThreadLocalMap map = getMap(t); //获取Thread对象的threadLocals成员变量
    if (map != null) map.set(this, value); //threadLocals不为空,则添加键值对
    else createMap(t, value); //threadLocals为空,则先创建再添加
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

2)get()函数的底层实现原理

get()函数的源码如下所示。get()函数的逻辑也比较简单。我们先获取当前线程的threadLocals成员变量。如果threadLocals不为null,那么,我们在threadLocals中查找当前操作的ThreadLocal变量对应的数据值。如果查找到对应的数据值,则直接返回。如果threadLocals为null,或者没有查找threadLocal变量对应的数据值,则调用initialValue()方法获取到threadLocal变量的初始值,创建threadLocals并添加键值对(threadLocal变量和初始值)。

public T get() {
    Thread t = Thread.currentThread(); //获取当前线程对应的Thread对象
    ThreadLocalMap map = getMap(t); //获取Thread对象的threadLocals成员变量
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this); //this是ThreadLocal变量
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result; //获取到对应的数据值
        }
    }
    //要么map为null,要么没有获取到对应的数据值,执行初始化操作
    return setInitialValue();
}

private T setInitialValue() {
    T value = initialValue(); //默认返回null,但可重新实现此函数,见上述Context示例
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) map.set(this, value); //添加键值对
    else createMap(t, value); //创建threadLocals
    return value;
}

3)remove()函数的底层实现原理

remove()函数的源码如下所示,相对来说,更加简单。对其代码逻辑,我们就不再赘述了。

 public void remove() {
     ThreadLocalMap m = getMap(Thread.currentThread());
     if (m != null) m.remove(this);
 }

三、应用场景

前面我们讲到的ReentrantReadWriteLock,其底层实现也用到了ThreadLocal。我们一块回忆一下。对于ReentrantReadWriteLock,AQS中的state同时存储写锁和读锁的加锁情况。state的低16位存储写锁的加锁情况,值为0表示没有加写锁,值为1表示已加写锁,值大于1表示写锁的重入次数。state的高16位存储读锁的加锁情况,值为0表示没有加读锁,值为1表示已加读锁,不过,值大于1并不表示读锁的重入次数,而是表示读锁总共被获取了多少次(每个线程对读锁重入的次数相加),此值用来最终解锁读锁。而每个线程对读锁的重入次数是有用信息,只有重入次数大于0时,线程才可以继续重入。那么,重入次数在哪里记录呢?因为重入次数是跟每个线程相关的数据,所以,我们就可以使用ThreadLocal变量来存储它。对应的源码如下所示。

//以下代码位于ReentrantReadWriteLock.java中
static final class HoldCounter {
    int count = 0;
    final long tid = getThreadId(Thread.currentThread());
}

static final class ThreadLocalHoldCounter extends ThreadLocal<HoldCounter> {
    public HoldCounter initialValue() {
        return new HoldCounter();
    }
}

private transient ThreadLocalHoldCounter readHolds;

四、课后思考题

在本节的示例中,我们使用时间戳来生成traceId,如果我们更改traceId的生成方式,使用自增的方式生成traceId,如下代码所示,那么,请问下面的代码是否是线程安全的?如果不是,应该如何修改才能保证其线程安全性?

public class Context {
  private static int id = 0;
  private static final ThreadLocal<String> threadLocalTraceId = 
      new ThreadLocal<String>() {
    @Override
    protected String initialValue() {
      id++;
      return String.valueOf(id);    }
  };

  public static void setTraceId(String traceId) {
    threadLocalTraceId.set(traceId);
  }

  public static String getTraceId() {
    return threadLocalTraceId.get();
  }

  public static void remove() {
    threadLocalTraceId.remove();
  }
}
Loading...